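"""Gradio front end for the SiliconFlow image-generation API.

Builds one tab per FLUX model, optionally enhances or translates prompts
through a chat-completion model, and exposes the local server through a
Cloudflare tunnel started as a subprocess.
"""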
import gradio as gr
import requests
from PIL import Image
import io
import os
import subprocess
import sys
import atexit

API_KEY = os.environ.get("SILICONFLOW_API_KEY")
if not API_KEY:
    raise ValueError("Please set the SILICONFLOW_API_KEY environment variable")

API_URL = "https://api.siliconflow.cn/v1/image/generations"
TEXT_API_URL = "https://api.siliconflow.cn/v1/chat/completions"
HEADERS = {
    "accept": "application/json",
    "content-type": "application/json",
    "Authorization": f"Bearer {API_KEY}"
}

# Read the tunnel token from the environment so the credential is not
# committed with the source code.
CLOUDFLARED_TOKEN = os.environ.get("CLOUDFLARED_TOKEN")

cloudflared_process = None

def start_cloudflared(port):
    """Start a cloudflared tunnel that forwards traffic to the local Gradio port."""
    global cloudflared_process
    if not CLOUDFLARED_TOKEN:
        print("CLOUDFLARED_TOKEN is not set; skipping tunnel startup.")
        return
    cloudflared_process = subprocess.Popen(
        ["cloudflared", "tunnel", "--url", f"http://localhost:{port}",
         "run", "--token", CLOUDFLARED_TOKEN]
    )
    atexit.register(stop_cloudflared)

def stop_cloudflared():
    """Terminate the tunnel process on interpreter exit, if one was started."""
    global cloudflared_process
    if cloudflared_process:
        cloudflared_process.terminate()
        cloudflared_process.wait()

def generate_image(model, prompt, image_size, batch_size=1, num_inference_steps=20):
    """Call the image-generation endpoint and return a list of PIL images,
    or an error string if anything fails."""
    data = {
        "model": model,
        "prompt": prompt,
        "image_size": image_size,
        "batch_size": batch_size,
        "num_inference_steps": num_inference_steps
    }

    print(f"Sending request to API with data: {data}")
    response = requests.post(API_URL, headers=HEADERS, json=data, timeout=120)
    print(f"API response status code: {response.status_code}")

    if response.status_code == 200:
        response_json = response.json()
        print(f"API response: {response_json}")
        if "images" in response_json:
            images = []
            for img_data in response_json["images"]:
                image_url = img_data["url"]
                try:
                    image_response = requests.get(image_url, timeout=60)
                    image = Image.open(io.BytesIO(image_response.content))
                    images.append(image)
                except Exception as e:
                    print(f"Error fetching image: {e}")
            return images if images else "Error: No images could be fetched"
        return "Error: No image data in response"
    return f"Error: {response.status_code}, {response.text}"

def call_chat_model(prompt, task):
    """Send a single-turn request to the chat-completion endpoint
    (Qwen/Qwen2.5-7B-Instruct) and return the reply text, or an error string."""
    data = {
        "model": "Qwen/Qwen2.5-7B-Instruct",
        "messages": [
            {"role": "system", "content": f"You are an AI assistant that helps with {task}. Respond concisely."},
            {"role": "user", "content": prompt}
        ],
        "temperature": 0.7,
        "max_tokens": 150
    }
    response = requests.post(TEXT_API_URL, headers=HEADERS, json=data, timeout=60)
    if response.status_code == 200:
        return response.json()['choices'][0]['message']['content'].strip()
    return f"Error: {response.status_code}, {response.text}"

def enhance_prompt(prompt):
    return call_chat_model(f"Enhance this image prompt for better results: {prompt}", "prompt enhancement")

def translate_prompt(prompt):
    return call_chat_model(f"Translate this text to English: {prompt}", "translation")

def generate_with_options(model, prompt, image_width, image_height, batch_size, num_inference_steps, enhance, translate):
    try:
        if enhance:
            prompt = enhance_prompt(prompt)
        if translate:
            prompt = translate_prompt(prompt)
        if (enhance or translate) and prompt.startswith("Error:"):
            return None, prompt  # a helper call failed; surface its message
        image_size = f"{int(image_width)}x{int(image_height)}"  # the API expects integer dimensions
        result = generate_image(model, prompt, image_size, batch_size, num_inference_steps)
        if isinstance(result, str):  # a string result is an error message
            return None, result      # no images, show the message
        return result, None          # list of images, no error message
    except Exception as e:
        return None, f"An error occurred: {str(e)}"

def create_model_interface(model_name, default_steps=20):
    """Build one tab of controls and outputs wired to a single model."""
    with gr.Tab(model_name):
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt")
                enhance = gr.Checkbox(label="Enhance Prompt")
                translate = gr.Checkbox(label="Translate Prompt")
                image_width = gr.Number(label="Image Width", value=1024, step=64, precision=0)
                image_height = gr.Number(label="Image Height", value=1024, step=64, precision=0)
                batch_size = gr.Slider(minimum=1, maximum=4, step=1, label="Batch Size", value=1)
                num_inference_steps = gr.Slider(minimum=1, maximum=50, step=1, label="Inference Steps", value=default_steps)
                generate_button = gr.Button("Generate")
            with gr.Column():
                output = gr.Gallery(label="Generated Images")
                error_output = gr.Textbox(label="Error Message")  # kept visible so error strings actually show

        generate_button.click(
            fn=generate_with_options,
            inputs=[
                gr.State(model_name),  # pins this tab to its model
                prompt,
                image_width,
                image_height,
                batch_size,
                num_inference_steps,
                enhance,
                translate
            ],
            outputs=[output, error_output]
        )
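# Design note: passing the model name through gr.State lets all three tabs
# share generate_with_options instead of defining one closure per tab.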

with gr.Blocks() as demo:
    gr.Markdown("# Image Generation with FLUX Models")
    
    create_model_interface("black-forest-labs/FLUX.1-dev", default_steps=20)
    create_model_interface("black-forest-labs/FLUX.1-schnell")
    create_model_interface("Pro/black-forest-labs/FLUX.1-schnell")

if __name__ == "__main__":
    port = 7860  # or any other port you prefer
    print("Starting Cloudflared...")
    start_cloudflared(port)
    print("Cloudflared started. Launching Gradio app...")
    demo.launch(server_name="0.0.0.0", server_port=port)
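# To run this script (a sketch of the expected setup; "app.py" is a
# placeholder for whatever this file is named):
#   export SILICONFLOW_API_KEY=...   # key for api.siliconflow.cn
#   export CLOUDFLARED_TOKEN=...     # cloudflared named-tunnel token
#   python app.py
# The cloudflared binary must be on PATH for the tunnel to start.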