import gradio as gr
import requests
from PIL import Image
import io
import os
import subprocess
import sys
import atexit

API_KEY = os.environ.get("SILICONFLOW_API_KEY")
if not API_KEY:
    raise ValueError("Please set the SILICONFLOW_API_KEY environment variable")

API_URL = "https://api.siliconflow.cn/v1/image/generations"
TEXT_API_URL = "https://api.siliconflow.cn/v1/chat/completions"

HEADERS = {
    "accept": "application/json",
    "content-type": "application/json",
    "Authorization": f"Bearer {API_KEY}"
}

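# For reference, generate_image() below assembles a JSON body of this shape
# for the image endpoint (values here are illustrative):
#   {
#       "model": "black-forest-labs/FLUX.1-dev",
#       "prompt": "...",
#       "image_size": "1024x1024",
#       "batch_size": 1,
#       "num_inference_steps": 20
#   }
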
# Cloudflared installation and setup
def install_cloudflared():
    if sys.platform.startswith('linux'):
        subprocess.run(["curl", "-L", "--output", "cloudflared",
                        "https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64"],
                       check=True)
        subprocess.run(["chmod", "+x", "cloudflared"], check=True)
    elif sys.platform == 'darwin':
        subprocess.run(["brew", "install", "cloudflare/cloudflare/cloudflared"], check=True)
    elif sys.platform == 'win32':
        subprocess.run(["powershell", "-Command",
                        "Invoke-WebRequest -Uri https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-windows-amd64.exe -OutFile cloudflared.exe"],
                       check=True)
    else:
        raise OSError("Unsupported operating system")

cloudflared_process = None

def start_cloudflared(port):
    global cloudflared_process
    if sys.platform.startswith('linux'):
        # Downloaded into the working directory by install_cloudflared()
        cloudflared_process = subprocess.Popen(["./cloudflared", "tunnel", "--url", f"http://localhost:{port}"])
    elif sys.platform == 'darwin':
        # Installed onto PATH by Homebrew, so no "./" prefix here
        cloudflared_process = subprocess.Popen(["cloudflared", "tunnel", "--url", f"http://localhost:{port}"])
    elif sys.platform == 'win32':
        cloudflared_process = subprocess.Popen(["cloudflared.exe", "tunnel", "--url", f"http://localhost:{port}"])
    atexit.register(stop_cloudflared)

def stop_cloudflared():
    global cloudflared_process
    if cloudflared_process:
        cloudflared_process.terminate()
        cloudflared_process.wait()

# The remaining functions are unchanged

def generate_image(model, prompt, image_size, batch_size=1, num_inference_steps=20):
    data = {
        "model": model,
        "prompt": prompt,
        "image_size": image_size,
        "batch_size": batch_size,
        "num_inference_steps": num_inference_steps
    }
    print(f"Sending request to API with data: {data}")
    response = requests.post(API_URL, headers=HEADERS, json=data)
    print(f"API response status code: {response.status_code}")
    if response.status_code == 200:
        response_json = response.json()
        print(f"API response: {response_json}")
        if "images" in response_json:
            images = []
            for img_data in response_json["images"]:
                image_url = img_data["url"]
                try:
                    image_response = requests.get(image_url)
                    image = Image.open(io.BytesIO(image_response.content))
                    images.append(image)
                except Exception as e:
                    print(f"Error fetching image: {e}")
            return images if images else "Error: No images could be fetched"
        else:
            return "No image data in response"
    else:
        return f"Error: {response.status_code}, {response.text}"

def use_gemma_model(prompt, task):
    data = {
        "model": "google/gemma-2-9b-it",
        "messages": [
            {"role": "system", "content": f"You are an AI assistant that helps with {task}. Respond concisely."},
            {"role": "user", "content": prompt}
        ],
        "temperature": 0.7,
        "max_tokens": 150
    }
    response = requests.post(TEXT_API_URL, headers=HEADERS, json=data)
    if response.status_code == 200:
        return response.json()['choices'][0]['message']['content'].strip()
    else:
        return f"Error: {response.status_code}, {response.text}"

def enhance_prompt(prompt):
    return use_gemma_model(f"Enhance this image prompt for better results: {prompt}", "prompt enhancement")

def translate_prompt(prompt):
    return use_gemma_model(f"Translate this text to English: {prompt}", "translation")

def generate_with_options(model, prompt, image_width, image_height, batch_size, num_inference_steps, enhance, translate):
    try:
        if enhance:
            prompt = enhance_prompt(prompt)
        if translate:
            prompt = translate_prompt(prompt)
        # gr.Number returns floats, so cast to int to avoid "1024.0x1024.0"
        image_size = f"{int(image_width)}x{int(image_height)}"
        result = generate_image(model, prompt, image_size, int(batch_size), int(num_inference_steps))
        if isinstance(result, str):  # an error message was returned
            return None, result      # no images; error text for the textbox
        return result, None          # a list of images; no error message
    except Exception as e:
        return None, f"An error occurred: {str(e)}"

def create_model_interface(model_name, default_steps=20):
    with gr.Tab(model_name):
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt")
                enhance = gr.Checkbox(label="Enhance Prompt")
                translate = gr.Checkbox(label="Translate Prompt")
                image_width = gr.Number(label="Image Width", value=1024, step=64)
                image_height = gr.Number(label="Image Height", value=1024, step=64)
                batch_size = gr.Slider(minimum=1, maximum=4, step=1, label="Batch Size", value=1)
                num_inference_steps = gr.Slider(minimum=1, maximum=50, step=1, label="Inference Steps", value=default_steps)
                generate_button = gr.Button("Generate")
            with gr.Column():
                output = gr.Gallery(label="Generated Images")
                error_output = gr.Textbox(label="Error Message", visible=False)
        generate_button.click(
            fn=generate_with_options,
            inputs=[
                # Hidden textbox that carries the model name into the callback
                gr.Textbox(value=model_name, visible=False),
                prompt,
                image_width,
                image_height,
                batch_size,
                num_inference_steps,
                enhance,
                translate
            ],
            outputs=[output, error_output]
        )

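# Note: create_model_interface() instantiates Gradio components, so it must be
# called inside a Blocks (or other layout) context, as done below.
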
with gr.Blocks() as demo:
    gr.Markdown("# Image Generation with FLUX Models")
    create_model_interface("black-forest-labs/FLUX.1-dev", default_steps=20)
    create_model_interface("black-forest-labs/FLUX.1-schnell")
    create_model_interface("Pro/black-forest-labs/FLUX.1-schnell")

if __name__ == "__main__":
    port = 7860  # or any other port you prefer
    print("Installing Cloudflared...")
    install_cloudflared()
    print("Starting Cloudflared...")
    start_cloudflared(port)
    print("Cloudflared started. Launching Gradio app...")
    demo.launch(server_name="0.0.0.0", server_port=port)
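
# Note (assumption): for quick tunnels, cloudflared prints its public
# https://*.trycloudflare.com URL to its own log output; to capture it
# programmatically, start_cloudflared() could pass stderr=subprocess.PIPE to
# Popen and scan the emitted lines for that domain.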