Spaces:
Sleeping
Sleeping
import os | |
from flask import Flask, flash, request, redirect, url_for, session, jsonify | |
from flask_session import Session | |
from werkzeug.utils import secure_filename | |
from example_inference import example_inference | |
from flask import send_from_directory | |
from gradio_client import Client | |
import base64 | |
# Directory name stored in the Flask config; not read elsewhere in this file.
UPLOAD_FOLDER = 'images'
# Lowercase extensions (no leading dot) accepted by allowed_file().
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}

app = Flask(__name__)
# Server-side (filesystem-backed) sessions via Flask-Session; must be
# configured before Session(app) is initialised.
app.config.update(
    SESSION_PERMANENT=False,
    SESSION_TYPE="filesystem",
    UPLOAD_FOLDER=UPLOAD_FOLDER,
)
Session(app)

# Gradio client for the hosted Stable Diffusion 3 Space, created at import
# time and shared by generate_logo().
client = Client("stabilityai/stable-diffusion-3-medium")
def generate_logo(prompt, *, negative_prompt="", seed=0, randomize_seed=True,
                  width=1024, height=1024, guidance_scale=7,
                  num_inference_steps=50):
    """Generate a logo image for *prompt* via the module-level Gradio client.

    All keyword-only arguments are forwarded unchanged to the Space's
    ``/infer`` endpoint. Their defaults reproduce the request this function
    previously hard-coded, so ``generate_logo(prompt)`` behaves as before.

    Args:
        prompt: Text prompt describing the desired logo.
        negative_prompt: Text forwarded as the endpoint's negative prompt.
        seed: Seed value forwarded to the endpoint.
        randomize_seed: Forwarded flag; presumably tells the Space to ignore
            ``seed`` and pick its own (TODO confirm against the Space's API).
        width: Requested image width, forwarded as-is.
        height: Requested image height, forwarded as-is.
        guidance_scale: Forwarded guidance scale.
        num_inference_steps: Forwarded number of inference steps.

    Returns:
        Whatever ``client.predict`` returns for ``/infer``. Callers in this
        file index the result with ``[0]``.
    """
    return client.predict(
        prompt=prompt,
        negative_prompt=negative_prompt,
        seed=seed,
        randomize_seed=randomize_seed,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        api_name="/infer",
    )
def generate_response():
    """Handle a JSON logo-generation request and reply with an image URL.

    Expects a JSON object body containing a non-empty string ``prompt``.

    Returns:
        * ``(json_string_url, 200)`` when generation succeeds.
        * ``({"error": "Missing required parameters"}, 400)`` when the body
          is not a JSON object or has no usable ``prompt``.
        * ``({"Error": <message>}, 500)`` when generation itself fails.
    """
    # silent=True makes get_json() return None instead of raising on a
    # missing or malformed JSON body, so bad client input becomes a 400
    # rather than being reported as a server error.
    data = request.get_json(silent=True)
    prompt = data.get('prompt') if isinstance(data, dict) else None
    if not isinstance(prompt, str) or not prompt.strip():
        return jsonify({"error": "Missing required parameters"}), 400

    try:
        response = generate_logo(prompt)
    except Exception as e:  # request boundary: report upstream failures
        app.logger.exception("Logo generation failed")
        return jsonify({"Error": str(e)}), 500

    app.logger.info("generate_logo response: %s", response)
    # NOTE(review): assumes response[0] is a path the Space serves through
    # its /file= route; confirm against gradio_client's return value.
    file_url = (
        'https://stabilityai-stable-diffusion-3-medium.hf.space/file='
        f'{response[0]}'
    )
    return jsonify(file_url)
def allowed_file(filename, extensions=None):
    """Return True if *filename* ends in an allowed extension.

    The part after the last dot is compared case-insensitively. A name
    without a dot, or one ending in a dot, is rejected.

    Args:
        filename: Client-supplied file name, e.g. ``"logo.PNG"``.
        extensions: Optional collection of lowercase extensions (without the
            leading dot) to accept. Defaults to ``ALLOWED_EXTENSIONS``.

    Returns:
        bool: whether the extension is in the allow-list.
    """
    if extensions is None:
        extensions = ALLOWED_EXTENSIONS
    _, dot, ext = filename.rpartition('.')
    return bool(dot) and ext.lower() in extensions
def upload_file(transparent, color):
    """Run ``example_inference`` on an image uploaded in the ``file`` field.

    Args:
        transparent: The exact string ``"true"`` enables transparent output.
            Any other value disables it.
        color: Passed through unchanged to ``example_inference``.

    Returns:
        On success, whatever ``example_inference`` returns (presumably the
        path of the processed image; TODO confirm). On a client error, a
        ``{"status": "Failed", "message": ...}`` dict. For non-POST requests,
        ``{"message": "Get Request not allowed"}``.
    """
    transparent = transparent == "true"

    if request.method != 'POST':
        return {"message": "Get Request not allowed"}

    if 'file' not in request.files:
        flash('No file part')
        return {"status": "Failed", "message": "Please Provide file name(file)."}

    file = request.files['file']
    # `not` also covers a None filename, which would otherwise crash
    # allowed_file() below.
    if not file.filename:
        flash('No selected file')
        return {"status": "Failed", "message": "Filename Not Found."}

    if not allowed_file(file.filename):
        # Previously this case fell through to the "Get Request not allowed"
        # reply, which misreported why the upload was rejected.
        flash('File type not allowed')
        allowed = ', '.join(sorted(ALLOWED_EXTENSIONS))
        return {
            "status": "Failed",
            "message": f"Unsupported file type. Allowed types: {allowed}.",
        }

    # The former base64 encode/decode round-trip yielded exactly these bytes,
    # so they are handed to the model directly.
    image_bytes = file.read()
    return example_inference(image_bytes, transparent, color)
if __name__ == '__main__':
    # Flask development server only. Debug mode enables the interactive
    # Werkzeug debugger, which must never be reachable from the public
    # internet; run under a production WSGI server when deploying.
    app.debug = True
    # HOST/PORT environment overrides. The defaults match the previous
    # hard-coded behaviour (Flask's default host, port 8000).
    app.run(
        host=os.environ.get("HOST", "127.0.0.1"),
        port=int(os.environ.get("PORT", "8000")),
    )