xhxhdvduenxvxheje committed on
Commit
779833c
1 Parent(s): 6d17964

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +167 -0
app.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import requests
3
+ from PIL import Image
4
+ import io
5
+ import os
6
+ import subprocess
7
+ import sys
8
+ import atexit
9
+ import time
10
+
11
# SiliconFlow API key must come from the environment; fail fast at import time
# so the app never starts without credentials.
API_KEY = os.environ.get("SILICONFLOW_API_KEY")
if not API_KEY:
    raise ValueError("请设置SILICONFLOW_API_KEY环境变量")

# Endpoint for image-generation requests.
API_URL = "https://api.siliconflow.cn/v1/image/generations"
# Endpoint for chat completions (used for prompt enhancement/translation).
TEXT_API_URL = "https://api.siliconflow.cn/v1/chat/completions"
# Shared headers for every SiliconFlow call: JSON in/out, bearer-token auth.
HEADERS = {
    "accept": "application/json",
    "content-type": "application/json",
    "Authorization": f"Bearer {API_KEY}"
}
22
+
23
# Cloudflared installation and setup
def install_cloudflared():
    """Download or install the cloudflared binary for the current platform.

    Raises:
        OSError: if the platform is not Linux, macOS, or Windows.
        subprocess.CalledProcessError: if a download/install step fails
            (check=True makes failures loud instead of silently continuing
            without a usable binary).
    """
    if sys.platform.startswith('linux'):
        subprocess.run(
            ["curl", "-L", "--output", "cloudflared",
             "https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64"],
            check=True,
        )
        subprocess.run(["chmod", "+x", "cloudflared"], check=True)
    elif sys.platform == 'darwin':
        subprocess.run(["brew", "install", "cloudflare/cloudflare/cloudflared"], check=True)
    elif sys.platform == 'win32':
        subprocess.run(
            ["powershell", "-Command",
             "Invoke-WebRequest -Uri https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-windows-amd64.exe -OutFile cloudflared.exe"],
            check=True,
        )
    else:
        raise OSError("Unsupported operating system")
34
+
35
# Handle to the running cloudflared tunnel subprocess; None until start_cloudflared().
cloudflared_process = None
36
+
37
def start_cloudflared(port):
    """Start a cloudflared tunnel forwarding to http://localhost:<port>.

    Stores the process handle in the module-level ``cloudflared_process`` and
    registers a cleanup hook so the tunnel is terminated at interpreter exit.

    Args:
        port: local port the Gradio server listens on.

    Raises:
        OSError: on platforms not handled by install_cloudflared()
            (previously this silently did nothing and still registered atexit).
    """
    global cloudflared_process
    if sys.platform.startswith('linux') or sys.platform == 'darwin':
        cloudflared_process = subprocess.Popen(
            ["./cloudflared", "tunnel", "--url", f"http://localhost:{port}"])
    elif sys.platform == 'win32':
        cloudflared_process = subprocess.Popen(
            ["cloudflared.exe", "tunnel", "--url", f"http://localhost:{port}"])
    else:
        # Keep behavior consistent with install_cloudflared().
        raise OSError("Unsupported operating system")
    # Register only after a process actually started.
    atexit.register(stop_cloudflared)
44
+
45
def stop_cloudflared():
    """Terminate the cloudflared tunnel subprocess, if one is running."""
    global cloudflared_process
    proc = cloudflared_process
    if not proc:
        return
    proc.terminate()
    proc.wait()
50
+
51
# The remaining functions are unchanged.
def generate_image(model, prompt, image_size, batch_size=1, num_inference_steps=20):
    """Request images from the SiliconFlow image-generation API.

    Args:
        model: model identifier, e.g. "black-forest-labs/FLUX.1-dev".
        prompt: text prompt describing the desired image.
        image_size: "WIDTHxHEIGHT" string.
        batch_size: number of images the API should generate.
        num_inference_steps: diffusion steps per image.

    Returns:
        A list of PIL images on success, otherwise an error-message string
        (callers distinguish the two with isinstance checks).
    """
    data = {
        "model": model,
        "prompt": prompt,
        "image_size": image_size,
        "batch_size": batch_size,
        "num_inference_steps": num_inference_steps
    }

    print(f"Sending request to API with data: {data}")
    # Timeout prevents the UI worker from hanging forever on a stalled request.
    response = requests.post(API_URL, headers=HEADERS, json=data, timeout=120)
    print(f"API response status code: {response.status_code}")

    if response.status_code == 200:
        response_json = response.json()
        print(f"API response: {response_json}")
        if "images" in response_json:
            images = []
            for img_data in response_json["images"]:
                image_url = img_data["url"]
                try:
                    image_response = requests.get(image_url, timeout=60)
                    # Treat HTTP errors (4xx/5xx) as fetch failures instead of
                    # feeding an error page to PIL.
                    image_response.raise_for_status()
                    image = Image.open(io.BytesIO(image_response.content))
                    images.append(image)
                except Exception as e:
                    # Best-effort: skip images that fail to download or decode.
                    print(f"Error fetching image: {e}")
            return images if images else "Error: No images could be fetched"
        else:
            return "No image data in response"
    else:
        return f"Error: {response.status_code}, {response.text}"
83
+
84
def use_gemma_model(prompt, task):
    """Run a single chat completion against the Gemma model.

    Args:
        prompt: the user message to send.
        task: short task description injected into the system prompt.

    Returns:
        The assistant's reply (stripped) on success, otherwise an error string.
    """
    data = {
        "model": "google/gemma-2-9b-it",
        "messages": [
            {"role": "system", "content": f"You are an AI assistant that helps with {task}. Respond concisely."},
            {"role": "user", "content": prompt}
        ],
        "temperature": 0.7,
        "max_tokens": 150
    }
    # Timeout keeps a stalled API call from blocking the Gradio worker forever.
    response = requests.post(TEXT_API_URL, headers=HEADERS, json=data, timeout=60)
    if response.status_code == 200:
        return response.json()['choices'][0]['message']['content'].strip()
    else:
        return f"Error: {response.status_code}, {response.text}"
99
+
100
def enhance_prompt(prompt):
    """Ask the text model to rewrite *prompt* for better image results."""
    request = f"Enhance this image prompt for better results: {prompt}"
    return use_gemma_model(request, "prompt enhancement")
102
+
103
def translate_prompt(prompt):
    """Ask the text model to translate *prompt* into English."""
    request = f"Translate this text to English: {prompt}"
    return use_gemma_model(request, "translation")
105
+
106
def generate_with_options(model, prompt, image_width, image_height, batch_size, num_inference_steps, enhance, translate):
    """Optionally enhance/translate the prompt, then generate images.

    Args:
        model: model identifier forwarded to generate_image().
        prompt: raw user prompt.
        image_width / image_height: desired image dimensions in pixels.
        batch_size: number of images to request.
        num_inference_steps: diffusion steps per image.
        enhance: if True, rewrite the prompt via the text model first.
        translate: if True, translate the (possibly enhanced) prompt to English.

    Returns:
        (images, error): a list of images and None on success, or
        (None, message) when generation failed.
    """
    try:
        if enhance:
            prompt = enhance_prompt(prompt)
        if translate:
            prompt = translate_prompt(prompt)
        image_size = f"{image_width}x{image_height}"
        result = generate_image(model, prompt, image_size, batch_size, num_inference_steps)
        if isinstance(result, str):  # generate_image signals errors as strings
            return None, result
        # Success path: generate_image returned a list of images. (The original
        # had a second, unreachable `return result, None` after this branch.)
        return result, None
    except Exception as e:
        return None, f"An error occurred: {str(e)}"
121
+
122
def create_model_interface(model_name, default_steps=20):
    """Build one Gradio tab wired to generate_with_options for *model_name*.

    Args:
        model_name: model identifier used as the tab title and sent to the API.
        default_steps: initial value of the inference-steps slider.
    """
    with gr.Tab(model_name):
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt")
                enhance = gr.Checkbox(label="Enhance Prompt")
                translate = gr.Checkbox(label="Translate Prompt")
                image_width = gr.Number(label="Image Width", value=1024, step=64)
                image_height = gr.Number(label="Image Height", value=1024, step=64)
                batch_size = gr.Slider(minimum=1, maximum=4, step=1, label="Batch Size", value=1)
                num_inference_steps = gr.Slider(minimum=1, maximum=50, step=1, label="Inference Steps", value=default_steps)
                generate_button = gr.Button("Generate")
            with gr.Column():
                output = gr.Gallery(label="Generated Images")
                # Visible so error messages actually reach the user; the
                # original created this with visible=False and never toggled
                # visibility, so errors could never be seen.
                error_output = gr.Textbox(label="Error Message")

        # gr.State carries the constant model name; the original abused an
        # invisible Textbox created inside the event wiring for this.
        model_state = gr.State(model_name)

        generate_button.click(
            fn=generate_with_options,
            inputs=[
                model_state,
                prompt,
                image_width,
                image_height,
                batch_size,
                num_inference_steps,
                enhance,
                translate
            ],
            outputs=[output, error_output]
        )
152
+
153
with gr.Blocks() as demo:
    gr.Markdown("# Image Generation with FLUX Models")

    # One tab per supported FLUX model; all use the interface's 20-step default.
    for flux_model in (
        "black-forest-labs/FLUX.1-dev",
        "black-forest-labs/FLUX.1-schnell",
        "Pro/black-forest-labs/FLUX.1-schnell",
    ):
        create_model_interface(flux_model)
159
+
160
if __name__ == "__main__":
    port = 7860  # or another port you prefer
    print("Installing Cloudflared...")
    install_cloudflared()
    print("Starting Cloudflared...")
    start_cloudflared(port)
    print("Cloudflared started. Launching Gradio app...")
    # Bind to 0.0.0.0 so the cloudflared tunnel can reach the local server.
    demo.launch(server_name="0.0.0.0", server_port=port)