"""Gradio front-end for SiliconFlow FLUX image generation.

Optionally exposes the app through a Cloudflare quick tunnel (cloudflared).
Requires the SILICONFLOW_API_KEY environment variable.
"""
import atexit
import io
import os
import shutil
import subprocess
import sys
import time

import gradio as gr
import requests
from PIL import Image

API_KEY = os.environ.get("SILICONFLOW_API_KEY")
if not API_KEY:
    # Message: "Please set the SILICONFLOW_API_KEY environment variable".
    raise ValueError("请设置SILICONFLOW_API_KEY环境变量")

API_URL = "https://api.siliconflow.cn/v1/image/generations"
TEXT_API_URL = "https://api.siliconflow.cn/v1/chat/completions"
HEADERS = {
    "accept": "application/json",
    "content-type": "application/json",
    "Authorization": f"Bearer {API_KEY}",
}

# Network timeouts in seconds. Without them a stalled connection would block the
# Gradio worker forever. Image generation gets more headroom than chat requests.
IMAGE_API_TIMEOUT = 180
TEXT_API_TIMEOUT = 60
IMAGE_DOWNLOAD_TIMEOUT = 60

CLOUDFLARED_RELEASE_URL = "https://github.com/cloudflare/cloudflared/releases/latest/download"


# --- Cloudflared installation and tunnel management --------------------------

def _cloudflared_executable():
    """Return the command used to launch cloudflared on this platform.

    On Linux and Windows this is the binary that install_cloudflared() downloads
    into the current working directory. On macOS it is the Homebrew-installed
    binary found on PATH.

    Raises:
        OSError: if the platform is not Linux, macOS or Windows.
    """
    if sys.platform.startswith("linux"):
        return os.path.abspath("cloudflared")
    if sys.platform == "darwin":
        # Homebrew installs onto PATH; nothing is downloaded to ./cloudflared.
        return "cloudflared"
    if sys.platform == "win32":
        # Use an absolute path so the result does not depend on Windows'
        # executable search order.
        return os.path.abspath("cloudflared.exe")
    raise OSError("Unsupported operating system")


def install_cloudflared():
    """Make the cloudflared binary available for start_cloudflared().

    Linux and Windows: download the latest amd64 release into the current
    working directory, unless a file with that name already exists.
    macOS: install via Homebrew, unless cloudflared is already on PATH.

    Raises:
        OSError: on an unsupported platform.
        subprocess.CalledProcessError: if the download or install command fails.
    """
    executable = _cloudflared_executable()

    if sys.platform == "darwin":
        if shutil.which(executable) is None:
            subprocess.run(["brew", "install", "cloudflare/cloudflare/cloudflared"], check=True)
        return

    if os.path.exists(executable):
        # Reuse a previously downloaded binary instead of re-fetching it every launch.
        return

    if sys.platform == "win32":
        subprocess.run(
            [
                "powershell",
                "-Command",
                f"Invoke-WebRequest -Uri {CLOUDFLARED_RELEASE_URL}/cloudflared-windows-amd64.exe"
                " -OutFile cloudflared.exe",
            ],
            check=True,
        )
    else:  # Linux (other platforms were rejected by _cloudflared_executable).
        # --fail makes curl exit non-zero on HTTP errors instead of saving an
        # error page as the "binary".
        subprocess.run(
            [
                "curl",
                "-fL",
                "--output",
                executable,
                f"{CLOUDFLARED_RELEASE_URL}/cloudflared-linux-amd64",
            ],
            check=True,
        )
        os.chmod(executable, 0o755)


# Handle of the running cloudflared child process, or None when no tunnel is running.
cloudflared_process = None


def start_cloudflared(port):
    """Start a Cloudflare quick tunnel to http://localhost:<port> in the background.

    The child process is kept in the module-level ``cloudflared_process``, and
    stop_cloudflared() is registered with atexit so the tunnel is shut down when
    the interpreter exits. Calling this while a tunnel is still running does
    nothing.

    Raises:
        OSError: on an unsupported platform, or if the executable cannot be started.
    """
    global cloudflared_process
    if cloudflared_process is not None and cloudflared_process.poll() is None:
        return
    cloudflared_process = subprocess.Popen(
        [_cloudflared_executable(), "tunnel", "--url", f"http://localhost:{port}"]
    )
    atexit.register(stop_cloudflared)


def stop_cloudflared():
    """Terminate the cloudflared child process, if any.

    Waits up to 10 seconds for the process to exit, then kills it. Safe to call
    more than once.
    """
    global cloudflared_process
    process = cloudflared_process
    if process is None:
        return
    cloudflared_process = None
    if process.poll() is None:
        process.terminate()
        try:
            process.wait(timeout=10)
        except subprocess.TimeoutExpired:
            process.kill()
            process.wait()


# --- SiliconFlow API helpers --------------------------------------------------

def generate_image(model, prompt, image_size, batch_size=1, num_inference_steps=20):
    """Request images from the SiliconFlow image-generation endpoint.

    Args:
        model: model identifier sent as ``model``.
        prompt: text prompt.
        image_size: size string of the form ``"<width>x<height>"``.
        batch_size: number of images requested.
        num_inference_steps: number of inference steps requested.

    Returns:
        A list of PIL images on success. Otherwise an error message string:
        the HTTP status and body, a missing-``images`` notice, or a notice that
        none of the returned image URLs could be downloaded.
    """
    data = {
        "model": model,
        "prompt": prompt,
        "image_size": image_size,
        "batch_size": batch_size,
        "num_inference_steps": num_inference_steps,
    }
    print(f"Sending request to API with data: {data}")
    response = requests.post(API_URL, headers=HEADERS, json=data, timeout=IMAGE_API_TIMEOUT)
    print(f"API response status code: {response.status_code}")
    if response.status_code != 200:
        return f"Error: {response.status_code}, {response.text}"

    response_json = response.json()
    print(f"API response: {response_json}")
    if "images" not in response_json:
        return "No image data in response"

    images = []
    for img_data in response_json["images"]:
        image_url = img_data["url"]
        try:
            image_response = requests.get(image_url, timeout=IMAGE_DOWNLOAD_TIMEOUT)
            # Without this, an HTTP error page would be handed to PIL as image bytes.
            image_response.raise_for_status()
            image = Image.open(io.BytesIO(image_response.content))
            # Image.open is lazy; decode now so corrupt data is caught here.
            image.load()
            images.append(image)
        except Exception as e:  # best-effort: skip this image, keep the others
            print(f"Error fetching image: {e}")
    return images if images else "Error: No images could be fetched"


def use_gemma_model(prompt, task, *, strict=False):
    """Send a single-turn chat request to google/gemma-2-9b-it.

    Args:
        prompt: user message.
        task: short description of the task, inserted into the system prompt.
        strict: if True, raise on a non-200 response instead of returning an
            error string. Callers that feed the result back into another
            request should use this, so an error message is never mistaken for
            model output.

    Returns:
        The stripped assistant reply. On a non-200 response with strict=False,
        returns ``"Error: <status>, <body>"``.

    Raises:
        RuntimeError: on a non-200 response when strict=True.
        requests.RequestException: on network failure or timeout.
    """
    data = {
        "model": "google/gemma-2-9b-it",
        "messages": [
            {"role": "system", "content": f"You are an AI assistant that helps with {task}. Respond concisely."},
            {"role": "user", "content": prompt},
        ],
        "temperature": 0.7,
        "max_tokens": 150,
    }
    response = requests.post(TEXT_API_URL, headers=HEADERS, json=data, timeout=TEXT_API_TIMEOUT)
    if response.status_code == 200:
        return response.json()["choices"][0]["message"]["content"].strip()
    if strict:
        raise RuntimeError(f"{task} request failed: {response.status_code}, {response.text}")
    return f"Error: {response.status_code}, {response.text}"


def enhance_prompt(prompt, *, strict=False):
    """Ask the text model to rewrite *prompt* into a richer image prompt.

    ``strict`` is passed through to use_gemma_model.
    """
    return use_gemma_model(
        f"Enhance this image prompt for better results: {prompt}", "prompt enhancement", strict=strict
    )


def translate_prompt(prompt, *, strict=False):
    """Ask the text model to translate *prompt* into English.

    ``strict`` is passed through to use_gemma_model.
    """
    return use_gemma_model(f"Translate this text to English: {prompt}", "translation", strict=strict)


def generate_with_options(model, prompt, image_width, image_height, batch_size, num_inference_steps, enhance, translate):
    """Optionally enhance and/or translate the prompt, then generate images.

    Numeric UI values may arrive as floats (e.g. from gr.Number), so they are
    cast to int before use.

    Returns:
        ``(images, None)`` on success, or ``(None, error_message)`` on failure.
    """
    try:
        if not prompt or not prompt.strip():
            return None, "Error: Prompt is empty"
        # strict=True: a failed rewrite must abort instead of becoming the prompt.
        if enhance:
            prompt = enhance_prompt(prompt, strict=True)
        if translate:
            prompt = translate_prompt(prompt, strict=True)
        image_size = f"{int(image_width)}x{int(image_height)}"
        result = generate_image(model, prompt, image_size, int(batch_size), int(num_inference_steps))
    except Exception as e:  # UI boundary: report any failure as text
        return None, f"An error occurred: {str(e)}"

    if isinstance(result, str):  # generate_image signals errors with a string
        return None, result
    return result, None


# --- UI ------------------------------------------------------------------------

def create_model_interface(model_name, default_steps=20):
    """Add a tab to the current gr.Blocks context for generating with *model_name*."""

    def _on_generate(model, prompt, image_width, image_height, batch_size, num_inference_steps, enhance, translate):
        # Show the error box only when there is something to report.
        images, error = generate_with_options(
            model, prompt, image_width, image_height, batch_size, num_inference_steps, enhance, translate
        )
        return images, gr.update(value=error or "", visible=bool(error))

    with gr.Tab(model_name):
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt")
                enhance = gr.Checkbox(label="Enhance Prompt")
                translate = gr.Checkbox(label="Translate Prompt")
                # precision=0 makes Gradio return ints rather than floats such as 1024.0.
                image_width = gr.Number(label="Image Width", value=1024, step=64, precision=0)
                image_height = gr.Number(label="Image Height", value=1024, step=64, precision=0)
                batch_size = gr.Slider(minimum=1, maximum=4, step=1, label="Batch Size", value=1)
                num_inference_steps = gr.Slider(
                    minimum=1, maximum=50, step=1, label="Inference Steps", value=default_steps
                )
                generate_button = gr.Button("Generate")
            with gr.Column():
                output = gr.Gallery(label="Generated Images")
                error_output = gr.Textbox(label="Error Message", visible=False)

        generate_button.click(
            fn=_on_generate,
            inputs=[
                gr.State(model_name),  # constant per tab; no hidden widget needed
                prompt,
                image_width,
                image_height,
                batch_size,
                num_inference_steps,
                enhance,
                translate,
            ],
            outputs=[output, error_output],
        )


with gr.Blocks() as demo:
    gr.Markdown("# Image Generation with FLUX Models")
    create_model_interface("black-forest-labs/FLUX.1-dev", default_steps=20)
    create_model_interface("black-forest-labs/FLUX.1-schnell")
    create_model_interface("Pro/black-forest-labs/FLUX.1-schnell")


if __name__ == "__main__":
    port = 7860  # or any other port you want to use
    print("Installing Cloudflared...")
    install_cloudflared()
    print("Starting Cloudflared...")
    start_cloudflared(port)
    print("Cloudflared started. Launching Gradio app...")
    # NOTE(review): the tunnel only needs localhost; 0.0.0.0 additionally exposes
    # the app on every network interface. Kept to preserve existing behavior.
    demo.launch(server_name="0.0.0.0", server_port=port)