# Hugging Face Spaces page header captured with the file (build status: Building).
# It is not part of the program.
import gradio as gr | |
import requests | |
from PIL import Image | |
import io | |
import os | |
import subprocess | |
import sys | |
import atexit | |
import time | |
# SiliconFlow API configuration.
# The key is read from the environment only; the module refuses to load
# without it so that no request is ever sent unauthenticated.
API_KEY = os.getenv("SILICONFLOW_API_KEY")
if API_KEY is None or API_KEY == "":
    raise ValueError("请设置SILICONFLOW_API_KEY环境变量")

# Endpoints for image generation and chat completion.
API_URL = "https://api.siliconflow.cn/v1/image/generations"
TEXT_API_URL = "https://api.siliconflow.cn/v1/chat/completions"

# Headers shared by every request: JSON in, JSON out, bearer authentication.
HEADERS = {
    "accept": "application/json",
    "content-type": "application/json",
    "Authorization": "Bearer " + API_KEY,
}
# Cloudflared installation and setup
def install_cloudflared():
    """Download or install the cloudflared binary for the current platform.

    * Linux: downloads the amd64 release to ``./cloudflared`` with ``curl``
      and marks it executable.
    * macOS: installs the Homebrew formula ``cloudflare/cloudflare/cloudflared``.
    * Windows: downloads the amd64 release to ``cloudflared.exe`` with
      PowerShell.

    Raises:
        OSError: if ``sys.platform`` is none of Linux, macOS or Windows.
        subprocess.CalledProcessError: if any install command exits with a
            non-zero status (previously such failures were silently ignored).
    """
    release_base = "https://github.com/cloudflare/cloudflared/releases/latest/download"

    if sys.platform.startswith('linux'):
        commands = [
            # --fail makes curl exit non-zero on an HTTP error instead of
            # saving the error page as the "binary".
            ["curl", "--fail", "-L", "--output", "cloudflared", f"{release_base}/cloudflared-linux-amd64"],
            ["chmod", "+x", "cloudflared"],
        ]
    elif sys.platform == 'darwin':
        commands = [["brew", "install", "cloudflare/cloudflare/cloudflared"]]
    elif sys.platform == 'win32':
        commands = [
            [
                "powershell",
                "-Command",
                f"Invoke-WebRequest -Uri {release_base}/cloudflared-windows-amd64.exe -OutFile cloudflared.exe",
            ],
        ]
    else:
        raise OSError("Unsupported operating system")

    # check=True stops at the first failing step so a broken install is
    # reported here rather than as a confusing failure when the tunnel starts.
    for command in commands:
        subprocess.run(command, check=True)
# Handle of the running cloudflared tunnel process (None while no tunnel is up).
cloudflared_process = None


def start_cloudflared(port):
    """Start a cloudflared quick tunnel that forwards to ``http://localhost:<port>``.

    The child process handle is stored in the module-level
    ``cloudflared_process`` and ``stop_cloudflared`` is registered with
    ``atexit`` so the tunnel is shut down when the interpreter exits.

    Args:
        port: Local port the tunnel should forward to.

    On platforms other than Linux, macOS and Windows no process is started.
    """
    global cloudflared_process

    if sys.platform.startswith('linux') or sys.platform == 'darwin':
        # install_cloudflared() downloads ./cloudflared on Linux, but on macOS
        # it installs through Homebrew, which puts the binary on PATH rather
        # than in the working directory. Prefer the local copy when present
        # and fall back to a PATH lookup otherwise.
        executable = "./cloudflared" if os.path.exists("./cloudflared") else "cloudflared"
    elif sys.platform == 'win32':
        executable = "cloudflared.exe"
    else:
        executable = None

    if executable is not None:
        cloudflared_process = subprocess.Popen(
            [executable, "tunnel", "--url", f"http://localhost:{port}"]
        )
    atexit.register(stop_cloudflared)
def stop_cloudflared():
    """Terminate the cloudflared tunnel process, if one was started.

    Sends a terminate request, waits up to ten seconds for the process to
    exit, and kills it if it has not exited by then, so interpreter shutdown
    cannot hang on an unresponsive child. The module-level
    ``cloudflared_process`` is reset to ``None``, so calling this more than
    once is harmless.
    """
    global cloudflared_process

    process = cloudflared_process
    if process is None:
        return

    # Clear the global first so a repeated call (e.g. manual stop followed by
    # the atexit hook) does not touch the same process twice.
    cloudflared_process = None

    process.terminate()
    try:
        process.wait(timeout=10)
    except subprocess.TimeoutExpired:
        process.kill()
        process.wait()
# Remaining functions: image generation and prompt helpers
def generate_image(model, prompt, image_size, batch_size=1, num_inference_steps=20, timeout=120):
    """Request images from the image-generation API and download them.

    Args:
        model: Model identifier sent as ``"model"``.
        prompt: Text prompt sent as ``"prompt"``.
        image_size: Size string sent as ``"image_size"`` (callers in this
            file pass ``"<width>x<height>"``).
        batch_size: Value sent as ``"batch_size"``.
        num_inference_steps: Value sent as ``"num_inference_steps"``.
        timeout: Seconds allowed for each HTTP request (the generation call
            and every image download). Without a timeout, ``requests`` waits
            indefinitely on a stalled connection.

    Returns:
        A non-empty list of ``PIL.Image`` objects on success, otherwise a
        string describing the failure (non-200 status, missing ``"images"``
        key, or no image could be downloaded).

    Raises:
        requests.RequestException: if the generation request itself fails
            at the network level (including a timeout).
    """
    data = {
        "model": model,
        "prompt": prompt,
        "image_size": image_size,
        "batch_size": batch_size,
        "num_inference_steps": num_inference_steps,
    }
    print(f"Sending request to API with data: {data}")
    response = requests.post(API_URL, headers=HEADERS, json=data, timeout=timeout)
    print(f"API response status code: {response.status_code}")

    if response.status_code != 200:
        return f"Error: {response.status_code}, {response.text}"

    response_json = response.json()
    print(f"API response: {response_json}")
    if "images" not in response_json:
        return "No image data in response"

    images = []
    for img_data in response_json["images"]:
        # Each image is fetched on a best-effort basis: one bad entry must
        # not discard the images that did download.
        try:
            image_response = requests.get(img_data["url"], timeout=timeout)
            # Reject HTTP error pages instead of trying to decode them.
            image_response.raise_for_status()
            image = Image.open(io.BytesIO(image_response.content))
            # Image.open is lazy; decode now so corrupt data is caught here
            # rather than later when the image is displayed.
            image.load()
            images.append(image)
        except Exception as e:
            print(f"Error fetching image: {e}")

    return images if images else "Error: No images could be fetched"
def use_gemma_model(prompt, task, timeout=60):
    """Send *prompt* to the ``google/gemma-2-9b-it`` chat model and return its reply.

    Args:
        prompt: User message content.
        task: Short task description interpolated into the system message.
        timeout: Seconds allowed for the HTTP request. Without a timeout,
            ``requests`` waits indefinitely on a stalled connection.

    Returns:
        The stripped content of the first choice on a 200 response,
        otherwise the string ``"Error: <status>, <body>"``.

    Raises:
        requests.RequestException: if the request fails at the network level
            (including a timeout).
    """
    data = {
        "model": "google/gemma-2-9b-it",
        "messages": [
            {"role": "system", "content": f"You are an AI assistant that helps with {task}. Respond concisely."},
            {"role": "user", "content": prompt},
        ],
        "temperature": 0.7,
        "max_tokens": 150,
    }
    response = requests.post(TEXT_API_URL, headers=HEADERS, json=data, timeout=timeout)
    if response.status_code != 200:
        return f"Error: {response.status_code}, {response.text}"
    return response.json()['choices'][0]['message']['content'].strip()
def enhance_prompt(prompt):
    """Ask the text model to improve *prompt* for image generation.

    Returns whatever ``use_gemma_model`` returns for the "prompt enhancement"
    task.
    """
    request_text = f"Enhance this image prompt for better results: {prompt}"
    return use_gemma_model(request_text, "prompt enhancement")
def translate_prompt(prompt):
    """Ask the text model to translate *prompt* into English.

    Returns whatever ``use_gemma_model`` returns for the "translation" task.
    """
    request_text = f"Translate this text to English: {prompt}"
    return use_gemma_model(request_text, "translation")
def _is_text_api_error(text):
    """Return True if *text* has the ``Error: <3-digit status>, ...`` shape that use_gemma_model returns on failure."""
    prefix = "Error: "
    if not isinstance(text, str) or not text.startswith(prefix):
        return False
    rest = text[len(prefix):]
    return rest[:3].isdigit() and rest[3:5] == ", "


def generate_with_options(model, prompt, image_width, image_height, batch_size, num_inference_steps, enhance, translate):
    """Optionally enhance/translate the prompt, then generate images.

    Args:
        model: Model identifier passed to ``generate_image``.
        prompt: User prompt.
        image_width: Requested width; converted with ``int()``.
        image_height: Requested height; converted with ``int()``.
        batch_size: Number of images; converted with ``int()``.
        num_inference_steps: Inference steps; converted with ``int()``.
        enhance: If truthy, run the prompt through ``enhance_prompt`` first.
        translate: If truthy, run the prompt through ``translate_prompt``.

    Returns:
        ``(images, None)`` on success or ``(None, message)`` on failure.
        Exceptions are never propagated; they are converted to a message.
    """
    # Validate the size before spending any API calls. UI number fields can
    # deliver floats (1024.0) or None, which would otherwise produce sizes
    # such as "1024.0x1024.0" or "Nonex1024".
    try:
        width, height = int(image_width), int(image_height)
    except (TypeError, ValueError, OverflowError):
        return None, "Error: image width and height must be numbers"
    if width <= 0 or height <= 0:
        return None, "Error: image width and height must be positive"

    try:
        if enhance:
            prompt = enhance_prompt(prompt)
            # The text helper reports failures as a string; surface it
            # instead of generating an image of the error message.
            if _is_text_api_error(prompt):
                return None, prompt
        if translate:
            prompt = translate_prompt(prompt)
            if _is_text_api_error(prompt):
                return None, prompt

        image_size = f"{width}x{height}"
        result = generate_image(model, prompt, image_size, int(batch_size), int(num_inference_steps))
        if isinstance(result, str):  # generate_image returns a string only for errors
            return None, result
        return result, None
    except Exception as e:
        return None, f"An error occurred: {str(e)}"
def create_model_interface(model_name, default_steps=20):
    """Build one tab of the UI for a single image model.

    Must be called inside an active ``gr.Blocks`` context. The tab holds the
    prompt and generation controls on the left and the gallery plus an error
    box on the right; the Generate button calls ``generate_with_options``.

    Args:
        model_name: Model identifier, used as the tab title and passed to the
            handler as its first argument.
        default_steps: Initial value of the inference-steps slider.
    """
    with gr.Tab(model_name):
        with gr.Row():
            with gr.Column():
                # Hidden field that hands this tab's model identifier to the handler.
                model_field = gr.Textbox(value=model_name, visible=False)
                prompt = gr.Textbox(label="Prompt")
                enhance = gr.Checkbox(label="Enhance Prompt")
                translate = gr.Checkbox(label="Translate Prompt")
                # precision=0 makes the fields deliver integers, so the size
                # string is built as "1024x1024" rather than "1024.0x1024.0".
                image_width = gr.Number(label="Image Width", value=1024, step=64, precision=0)
                image_height = gr.Number(label="Image Height", value=1024, step=64, precision=0)
                batch_size = gr.Slider(minimum=1, maximum=4, step=1, label="Batch Size", value=1)
                num_inference_steps = gr.Slider(minimum=1, maximum=50, step=1, label="Inference Steps", value=default_steps)
                generate_button = gr.Button("Generate")
            with gr.Column():
                output = gr.Gallery(label="Generated Images")
                # The handler writes failure messages here. The box was
                # previously created hidden and never revealed, so users saw
                # an empty gallery with no explanation.
                error_output = gr.Textbox(label="Error Message", interactive=False)

        generate_button.click(
            fn=generate_with_options,
            inputs=[
                model_field,
                prompt,
                image_width,
                image_height,
                batch_size,
                num_inference_steps,
                enhance,
                translate,
            ],
            outputs=[output, error_output],
        )
# Assemble the application: a title followed by one tab per FLUX model.
with gr.Blocks() as demo:
    gr.Markdown("# Image Generation with FLUX Models")
    # (model identifier, extra keyword arguments) for each tab, in display order.
    for _model_name, _tab_options in (
        ("black-forest-labs/FLUX.1-dev", {"default_steps": 20}),
        ("black-forest-labs/FLUX.1-schnell", {}),
        ("Pro/black-forest-labs/FLUX.1-schnell", {}),
    ):
        create_model_interface(_model_name, **_tab_options)
if __name__ == "__main__":
    # Port shared by the Gradio server and the Cloudflared tunnel
    # (change it here if another port is preferred).
    port = 7860

    # The tunnel binary has to be present before it can be started.
    print("Installing Cloudflared...")
    install_cloudflared()

    print("Starting Cloudflared...")
    start_cloudflared(port)

    # Serve on all interfaces at the port the tunnel forwards to.
    print("Cloudflared started. Launching Gradio app...")
    launch_options = {"server_name": "0.0.0.0", "server_port": port}
    demo.launch(**launch_options)