operation / app.py
xhxhdvduenxvxheje's picture
Update app.py
eb2d37f verified
import gradio as gr
import requests
from PIL import Image
import io
import os
import subprocess
import sys
import atexit
# SiliconFlow credentials and endpoints. The key is mandatory: fail at import
# time with a clear message rather than on the first API call.
API_KEY = os.getenv("SILICONFLOW_API_KEY")
if not API_KEY:
    # Message: "Please set the SILICONFLOW_API_KEY environment variable"
    raise ValueError("请设置SILICONFLOW_API_KEY环境变量")

API_URL = "https://api.siliconflow.cn/v1/image/generations"  # text-to-image endpoint
TEXT_API_URL = "https://api.siliconflow.cn/v1/chat/completions"  # chat-completion endpoint

# JSON request/response headers plus bearer-token auth, shared by every call.
HEADERS = {header: "application/json" for header in ("accept", "content-type")}
HEADERS["Authorization"] = f"Bearer {API_KEY}"
# Cloudflare tunnel token, read from the environment in the same way as
# SILICONFLOW_API_KEY. It must never be committed to source; a token that has
# been checked into a public repository should be treated as leaked and rotated.
CLOUDFLARED_TOKEN = os.environ.get("CLOUDFLARED_TOKEN")

# Handle of the running `cloudflared` child process, or None when no tunnel
# has been started by this module (or it has been stopped).
cloudflared_process = None


def start_cloudflared(port):
    """Start a Cloudflare tunnel that forwards to the local app on *port*.

    The tunnel is skipped, with a printed notice, when CLOUDFLARED_TOKEN is
    unset or the ``cloudflared`` executable is not on PATH. This lets the app
    still run locally. Calling this while a previously started tunnel process
    is still alive does nothing. On a successful start, stop_cloudflared is
    registered to run at interpreter exit.

    Args:
        port: Local port the web app listens on.
    """
    global cloudflared_process
    if cloudflared_process is not None and cloudflared_process.poll() is None:
        return  # a tunnel is already running; don't spawn a duplicate
    if not CLOUDFLARED_TOKEN:
        print("CLOUDFLARED_TOKEN is not set; skipping Cloudflare tunnel.")
        return
    command = [
        "cloudflared", "tunnel",
        "--url", f"http://localhost:{port}",
        "run", "--token", CLOUDFLARED_TOKEN,
    ]
    try:
        cloudflared_process = subprocess.Popen(command)
    except FileNotFoundError:
        print("cloudflared executable not found on PATH; skipping Cloudflare tunnel.")
        return
    # unregister() removes any earlier registration, so restarting the tunnel
    # doesn't queue stop_cloudflared more than once.
    atexit.unregister(stop_cloudflared)
    atexit.register(stop_cloudflared)


def stop_cloudflared(timeout=10):
    """Stop the tunnel process started by start_cloudflared, if any.

    Asks the process to terminate and waits up to *timeout* seconds. If it is
    still alive after that, it is killed, so interpreter shutdown cannot hang
    on a stuck child. The module-level handle is cleared, which makes repeated
    calls harmless no-ops.

    Args:
        timeout: Seconds to wait for a graceful exit before killing.
    """
    global cloudflared_process
    process = cloudflared_process
    if process is None:
        return
    cloudflared_process = None
    if process.poll() is None:
        process.terminate()
    try:
        process.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        process.kill()
        process.wait()
def generate_image(model, prompt, image_size, batch_size=1, num_inference_steps=20, timeout=180):
    """Generate images with a SiliconFlow text-to-image model.

    Args:
        model: Model identifier sent in the request body.
        prompt: Text prompt describing the image.
        image_size: Size string sent as-is, e.g. "1024x1024".
        batch_size: Number of images requested.
        num_inference_steps: Number of inference steps requested.
        timeout: Seconds to wait on each HTTP request (the generation call and
            each image download) before giving up.

    Returns:
        A list of PIL images when at least one returned image could be
        downloaded and decoded. Otherwise an error-message string: for a
        non-200 status, a response without an "images" key, or no fetchable
        images.

    Raises:
        requests.RequestException: if the generation request itself fails
            (e.g. connection error or timeout).
    """
    data = {
        "model": model,
        "prompt": prompt,
        "image_size": image_size,
        "batch_size": batch_size,
        "num_inference_steps": num_inference_steps,
    }
    print(f"Sending request to API with data: {data}")
    # Without a timeout, requests would wait forever on a stalled connection.
    response = requests.post(API_URL, headers=HEADERS, json=data, timeout=timeout)
    print(f"API response status code: {response.status_code}")
    if response.status_code != 200:
        return f"Error: {response.status_code}, {response.text}"

    response_json = response.json()
    print(f"API response: {response_json}")
    if "images" not in response_json:
        return "No image data in response"

    images = []
    for img_data in response_json["images"]:
        # Best-effort per image: a malformed entry, failed download, or
        # undecodable payload drops only that image and is logged.
        try:
            image_response = requests.get(img_data["url"], timeout=timeout)
            image_response.raise_for_status()
            image = Image.open(io.BytesIO(image_response.content))
            # Image.open is lazy; load() forces decoding here so corrupt data
            # is caught now instead of later when the gallery renders it.
            image.load()
            images.append(image)
        except Exception as e:
            print(f"Error fetching image: {e}")
    return images if images else "Error: No images could be fetched"
def use_gemma_model(prompt, task, model="Qwen/Qwen2.5-7B-Instruct", timeout=60):
    """Run a short chat completion for a helper text task.

    Despite its name, this function defaults to Qwen/Qwen2.5-7B-Instruct. The
    name is kept for backward compatibility with existing callers.

    Args:
        prompt: User message sent to the model.
        task: Short task description inserted into the system prompt.
        model: Chat model identifier.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The stripped assistant reply on success. An "Error: ..." string on a
        non-200 status or when the response body lacks the expected
        choices[0].message.content string.

    Raises:
        requests.RequestException: if the request itself fails (e.g.
            connection error or timeout).
    """
    data = {
        "model": model,
        "messages": [
            {"role": "system", "content": f"You are an AI assistant that helps with {task}. Respond concisely."},
            {"role": "user", "content": prompt},
        ],
        "temperature": 0.7,
        "max_tokens": 150,
    }
    response = requests.post(TEXT_API_URL, headers=HEADERS, json=data, timeout=timeout)
    if response.status_code != 200:
        return f"Error: {response.status_code}, {response.text}"
    try:
        content = response.json()["choices"][0]["message"]["content"]
    except (ValueError, KeyError, IndexError, TypeError):
        # Body isn't JSON or doesn't have the chat-completion shape.
        return f"Error: unexpected response format, {response.text}"
    if not isinstance(content, str):
        return f"Error: unexpected response format, {response.text}"
    return content.strip()
def enhance_prompt(prompt):
    """Ask the helper chat model to rewrite *prompt* as a stronger image prompt.

    Returns whatever use_gemma_model returns for the request.
    """
    request_text = f"Enhance this image prompt for better results: {prompt}"
    return use_gemma_model(request_text, "prompt enhancement")
def translate_prompt(prompt):
    """Ask the helper chat model to translate *prompt* into English.

    Returns whatever use_gemma_model returns for the request.
    """
    instruction = "Translate this text to English: " + f"{prompt}"
    return use_gemma_model(prompt=instruction, task="translation")
def generate_with_options(model, prompt, image_width, image_height, batch_size, num_inference_steps, enhance, translate):
    """Gradio click handler: optionally rewrite the prompt, then generate images.

    Args:
        model: Model identifier passed through to generate_image.
        prompt: User prompt text.
        image_width: Requested width in pixels.
        image_height: Requested height in pixels. gr.Number delivers floats
            unless precision=0, so both sizes are converted to int before
            building the "WIDTHxHEIGHT" string (otherwise it would read
            "1024.0x1024.0").
        batch_size: Number of images; converted to int.
        num_inference_steps: Inference steps; converted to int.
        enhance: If truthy, rewrite the prompt with enhance_prompt first.
        translate: If truthy, translate the (possibly enhanced) prompt to English.

    Returns:
        (images, None) on success, or (None, error_message) on failure. The
        pair maps onto the gallery output and the error-message output.
    """
    if not prompt or not prompt.strip():
        return None, "Error: Prompt is empty"
    if image_width is None or image_height is None:
        return None, "Error: Image width and height are required"
    try:
        if enhance:
            prompt = enhance_prompt(prompt)
            # The helper reports failures as "Error: ..." strings; surface
            # them instead of silently using them as the image prompt.
            if prompt.startswith("Error:"):
                return None, f"Prompt enhancement failed: {prompt}"
        if translate:
            prompt = translate_prompt(prompt)
            if prompt.startswith("Error:"):
                return None, f"Prompt translation failed: {prompt}"
        image_size = f"{int(image_width)}x{int(image_height)}"
        result = generate_image(model, prompt, image_size, int(batch_size), int(num_inference_steps))
        if isinstance(result, str):  # generate_image reports errors as strings
            return None, result
        return result, None
    except Exception as e:
        # UI boundary: report any failure in the error box instead of crashing.
        return None, f"An error occurred: {e}"
def create_model_interface(model_name, default_steps=20):
    """Add a tab with prompt controls and an image gallery for one model.

    Must be called inside an active ``gr.Blocks()`` context. The tab's
    Generate button calls generate_with_options with *model_name* plus the
    current control values. The results go to the gallery and the
    error-message box.

    Args:
        model_name: Model identifier; also used as the tab title.
        default_steps: Initial value of the inference-steps slider.
    """
    with gr.Tab(model_name):
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt")
                enhance = gr.Checkbox(label="Enhance Prompt")
                translate = gr.Checkbox(label="Translate Prompt")
                # precision=0 makes gr.Number deliver ints. The default
                # (precision=None) yields floats such as 1024.0.
                image_width = gr.Number(label="Image Width", value=1024, step=64, precision=0)
                image_height = gr.Number(label="Image Height", value=1024, step=64, precision=0)
                batch_size = gr.Slider(minimum=1, maximum=4, step=1, label="Batch Size", value=1)
                num_inference_steps = gr.Slider(minimum=1, maximum=50, step=1, label="Inference Steps", value=default_steps)
                generate_button = gr.Button("Generate")
            with gr.Column():
                output = gr.Gallery(label="Generated Images")
                # Kept visible: with visible=False, error messages were written
                # to a component the user could never see.
                error_output = gr.Textbox(label="Error Message", interactive=False)
        generate_button.click(
            fn=generate_with_options,
            inputs=[
                gr.State(model_name),  # constant per tab, not a UI control
                prompt,
                image_width,
                image_height,
                batch_size,
                num_inference_steps,
                enhance,
                translate,
            ],
            outputs=[output, error_output],
        )
# Top-level Gradio app: a heading followed by one tab per FLUX model, each
# built with 20 default inference steps.
with gr.Blocks() as demo:
    gr.Markdown("# Image Generation with FLUX Models")
    for _flux_model_id in (
        "black-forest-labs/FLUX.1-dev",
        "black-forest-labs/FLUX.1-schnell",
        "Pro/black-forest-labs/FLUX.1-schnell",
    ):
        create_model_interface(_flux_model_id, default_steps=20)
if __name__ == "__main__":
    # Port shared by the Cloudflare tunnel and the Gradio server; change it
    # here to serve on a different port.
    app_port = 7860
    print("Starting Cloudflared...")
    start_cloudflared(app_port)
    print("Cloudflared started. Launching Gradio app...")
    # Bind on all interfaces so the tunnel (and the container host) can reach it.
    demo.launch(server_port=app_port, server_name="0.0.0.0")