Spaces:
Building
Building
File size: 5,991 Bytes
779833c eb2d37f 779833c eb2d37f 779833c eb2d37f 779833c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import gradio as gr
import requests
from PIL import Image
import io
import os
import subprocess
import sys
import atexit
# SiliconFlow credentials are mandatory: fail fast at import time when absent.
# (The message below reads: "Please set the SILICONFLOW_API_KEY environment variable".)
API_KEY = os.getenv("SILICONFLOW_API_KEY")
if not API_KEY:
    raise ValueError("请设置SILICONFLOW_API_KEY环境变量")

# Both endpoints live under the same SiliconFlow v1 API root.
_SILICONFLOW_BASE = "https://api.siliconflow.cn/v1"
API_URL = _SILICONFLOW_BASE + "/image/generations"
TEXT_API_URL = _SILICONFLOW_BASE + "/chat/completions"

# JSON request/response headers plus bearer authentication, shared by every call.
HEADERS = {
    "accept": "application/json",
    "content-type": "application/json",
    "Authorization": "Bearer " + API_KEY,
}
# Cloudflare tunnel token, read from the environment. An earlier revision of
# this file hard-coded the token in source; that value should be treated as
# leaked and rotated in the Cloudflare dashboard.
CLOUDFLARED_TOKEN = os.environ.get("CLOUDFLARED_TOKEN")

# Handle of the running `cloudflared` child process, or None when no tunnel is up.
cloudflared_process = None


def start_cloudflared(port):
    """Start a `cloudflared` tunnel that forwards to http://localhost:<port>.

    Args:
        port: Local port the Gradio server listens on.

    Returns:
        The `subprocess.Popen` handle of the tunnel process. Returns None when
        CLOUDFLARED_TOKEN is unset, in which case no tunnel is started. If a
        tunnel started by this module is still running, that handle is
        returned and no second process is spawned.

    Raises:
        OSError: if the `cloudflared` executable cannot be launched.
    """
    global cloudflared_process
    # Reuse a live tunnel instead of leaking it by spawning another one.
    if cloudflared_process is not None and cloudflared_process.poll() is None:
        return cloudflared_process
    if not CLOUDFLARED_TOKEN:
        print("CLOUDFLARED_TOKEN is not set; skipping Cloudflare tunnel.")
        return None
    cloudflared_process = subprocess.Popen(
        [
            "cloudflared", "tunnel",
            "--url", f"http://localhost:{port}",
            "run", "--token", CLOUDFLARED_TOKEN,
        ]
    )
    # stop_cloudflared is idempotent, so registering it once per spawn is harmless.
    atexit.register(stop_cloudflared)
    return cloudflared_process


def stop_cloudflared(timeout=10):
    """Terminate the tunnel process started by start_cloudflared, if any.

    Sends terminate first. If the process has not exited within *timeout*
    seconds, it is killed so interpreter shutdown cannot hang. Safe to call
    repeatedly.
    """
    global cloudflared_process
    proc = cloudflared_process
    if proc is None:
        return
    cloudflared_process = None
    if proc.poll() is None:
        proc.terminate()
        try:
            proc.wait(timeout=timeout)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()
def generate_image(model, prompt, image_size, batch_size=1, num_inference_steps=20, timeout=120):
    """Request images from the SiliconFlow image-generation endpoint.

    Args:
        model: Model id sent as the "model" field.
        prompt: Text prompt.
        image_size: Size string of the form "<width>x<height>".
        batch_size: Number of images requested.
        num_inference_steps: Diffusion steps requested.
        timeout: Seconds to wait for each HTTP request (the generation call
            and every image download) before `requests` raises Timeout.

    Returns:
        A list of PIL images on success. On failure it returns a string
        instead: "Error: <status>, <body>" for a non-200 response,
        "No image data in response" when the JSON has no "images" key, or
        "Error: No images could be fetched" when no image could be downloaded.
    """
    data = {
        "model": model,
        "prompt": prompt,
        "image_size": image_size,
        "batch_size": batch_size,
        "num_inference_steps": num_inference_steps
    }
    print(f"Sending request to API with data: {data}")
    response = requests.post(API_URL, headers=HEADERS, json=data, timeout=timeout)
    print(f"API response status code: {response.status_code}")
    if response.status_code != 200:
        return f"Error: {response.status_code}, {response.text}"

    response_json = response.json()
    print(f"API response: {response_json}")
    if "images" not in response_json:
        return "No image data in response"

    images = []
    for img_data in response_json["images"]:
        # Best-effort per image: one bad entry (missing url, HTTP error,
        # undecodable bytes) is logged and skipped without losing the others.
        try:
            image_response = requests.get(img_data["url"], timeout=timeout)
            image_response.raise_for_status()
            image = Image.open(io.BytesIO(image_response.content))
            # Image.open is lazy; decode now so corrupt data fails here, not later.
            image.load()
            images.append(image)
        except Exception as e:
            print(f"Error fetching image: {e}")
    return images if images else "Error: No images could be fetched"
def use_gemma_model(prompt, task, timeout=60):
    """Send *prompt* to the SiliconFlow chat-completions endpoint and return the reply.

    Despite the function name, the model used is Qwen/Qwen2.5-7B-Instruct.

    Args:
        prompt: User message content.
        task: Short description of the job, interpolated into the system prompt.
        timeout: Seconds to wait for the HTTP response before `requests`
            raises Timeout.

    Returns:
        The assistant's reply with surrounding whitespace stripped. On failure
        it returns a string starting with "Error: ": either the HTTP status and
        body for a non-200 response, or a description of a malformed 200
        response.
    """
    data = {
        "model": "Qwen/Qwen2.5-7B-Instruct",
        "messages": [
            {"role": "system", "content": f"You are an AI assistant that helps with {task}. Respond concisely."},
            {"role": "user", "content": prompt}
        ],
        "temperature": 0.7,
        "max_tokens": 150
    }
    response = requests.post(TEXT_API_URL, headers=HEADERS, json=data, timeout=timeout)
    if response.status_code != 200:
        return f"Error: {response.status_code}, {response.text}"
    # A 200 with an unexpected body (non-JSON, no choices, null content) is
    # reported with the same "Error: " convention instead of raising.
    try:
        return response.json()['choices'][0]['message']['content'].strip()
    except (ValueError, KeyError, IndexError, TypeError, AttributeError) as exc:
        return f"Error: unexpected response format ({exc!r}): {response.text}"
def enhance_prompt(prompt):
    """Ask the chat model to rewrite *prompt* into a stronger image prompt.

    Returns whatever use_gemma_model returns: the rewritten prompt, or its
    "Error: ..." string when the request fails.
    """
    instruction = "Enhance this image prompt for better results: " + f"{prompt}"
    return use_gemma_model(instruction, task="prompt enhancement")
def translate_prompt(prompt):
    """Ask the chat model to translate *prompt* into English.

    Returns whatever use_gemma_model returns: the translation, or its
    "Error: ..." string when the request fails.
    """
    instruction = "Translate this text to English: " + f"{prompt}"
    return use_gemma_model(instruction, task="translation")
def _to_image_size(width, height):
    """Format width/height (Gradio may deliver floats) as the API's "<w>x<h>" string."""
    return f"{int(width)}x{int(height)}"


def _looks_like_error(text):
    """Return True if *text* follows the "Error: ..." failure convention of the API helpers."""
    return isinstance(text, str) and text.startswith("Error: ")


def generate_with_options(model, prompt, image_width, image_height, batch_size, num_inference_steps, enhance, translate):
    """Gradio callback: optionally enhance/translate the prompt, then generate images.

    Args:
        model: Model id for the image endpoint.
        prompt: User prompt from the textbox.
        image_width, image_height: Target size. These may arrive as floats from
            gr.Number and are truncated to int.
        batch_size, num_inference_steps: Generation parameters (cast to int).
        enhance: Run the prompt through enhance_prompt first.
        translate: Run the prompt through translate_prompt (after enhancing).

    Returns:
        (images, None) on success, or (None, error_message) on any failure.
        This matches the (gallery, error textbox) outputs.
    """
    try:
        if not (prompt or "").strip():
            return None, "Error: Prompt is empty"
        if enhance:
            prompt = enhance_prompt(prompt)
            # Otherwise an API error string would silently become the image prompt.
            if _looks_like_error(prompt):
                return None, f"Prompt enhancement failed: {prompt}"
        if translate:
            prompt = translate_prompt(prompt)
            if _looks_like_error(prompt):
                return None, f"Prompt translation failed: {prompt}"
        image_size = _to_image_size(image_width, image_height)
        result = generate_image(model, prompt, image_size, int(batch_size), int(num_inference_steps))
        if isinstance(result, str):  # generate_image reports failures as strings
            return None, result
        return result, None
    except Exception as e:
        # Top-level UI boundary: show any unexpected failure in the error box.
        return None, f"An error occurred: {str(e)}"
def create_model_interface(model_name, default_steps=20):
    """Add a tab whose controls generate images with *model_name*.

    Must be called inside an active ``gr.Blocks`` context. The Generate button
    calls generate_with_options and routes its (images, error) result to the
    gallery and the error textbox.

    Args:
        model_name: Model id; also used as the tab title.
        default_steps: Initial value of the inference-steps slider.
    """
    with gr.Tab(model_name):
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt")
                enhance = gr.Checkbox(label="Enhance Prompt")
                translate = gr.Checkbox(label="Translate Prompt")
                # precision=0 makes Gradio return ints, so the size string is
                # "1024x1024" rather than "1024.0x1024.0".
                image_width = gr.Number(label="Image Width", value=1024, step=64, precision=0)
                image_height = gr.Number(label="Image Height", value=1024, step=64, precision=0)
                batch_size = gr.Slider(minimum=1, maximum=4, step=1, label="Batch Size", value=1)
                num_inference_steps = gr.Slider(minimum=1, maximum=50, step=1, label="Inference Steps", value=default_steps)
                generate_button = gr.Button("Generate")
            with gr.Column():
                output = gr.Gallery(label="Generated Images")
                # Kept visible: the callback returns plain strings (no
                # visibility update), so a hidden box would swallow every error.
                error_output = gr.Textbox(label="Error Message", interactive=False)
        # gr.State passes the fixed model id into the callback without a
        # hidden widget.
        model_state = gr.State(value=model_name)
        generate_button.click(
            fn=generate_with_options,
            inputs=[
                model_state,
                prompt,
                image_width,
                image_height,
                batch_size,
                num_inference_steps,
                enhance,
                translate,
            ],
            outputs=[output, error_output],
        )
# One tab per model, in display order: (model id, extra create_model_interface kwargs).
_FLUX_TABS = (
    ("black-forest-labs/FLUX.1-dev", {"default_steps": 20}),
    ("black-forest-labs/FLUX.1-schnell", {}),
    ("Pro/black-forest-labs/FLUX.1-schnell", {}),
)

# Top-level Gradio app: a title followed by the per-model tabs.
with gr.Blocks() as demo:
    gr.Markdown("# Image Generation with FLUX Models")
    for _tab_model, _tab_kwargs in _FLUX_TABS:
        create_model_interface(_tab_model, **_tab_kwargs)
if __name__ == "__main__":
    # Port defaults to 7860; set GRADIO_SERVER_PORT (the variable Gradio itself
    # reads) to use a different one. The same port is given to the tunnel and
    # to Gradio so they stay in sync.
    port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
    print("Starting Cloudflared...")
    start_cloudflared(port)
    print("Cloudflared started. Launching Gradio app...")
    demo.launch(server_name="0.0.0.0", server_port=port)