Spaces:
Runtime error
Runtime error
File size: 4,757 Bytes
64c4ea7 e5685f4 64c4ea7 b0ec6f8 64c4ea7 e5685f4 64c4ea7 e5685f4 94c0acf 64c4ea7 4148997 94c0acf 64c4ea7 94c0acf e5685f4 94c0acf 64c4ea7 94c0acf e5685f4 94c0acf 64c4ea7 e5685f4 64c4ea7 e5685f4 64c4ea7 e5685f4 64c4ea7 e5685f4 94c0acf e5685f4 94c0acf e5685f4 64c4ea7 e5685f4 64c4ea7 cf80bf3 e5685f4 cf80bf3 e5685f4 64c4ea7 e5685f4 64c4ea7 e5685f4 cf80bf3 64c4ea7 94c0acf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import base64
import io
from fastapi import FastAPI, UploadFile, File, HTTPException
import os
from PIL import Image
from fastapi.responses import JSONResponse
from semantic_seg_model import segmentation_inference
from similarity_inference import similarity_inference
from gradio_client import Client, file
from datetime import datetime
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI(docs_url="/")  # interactive API docs are served at the root path

# CORS: let browsers call this API from any origin, with credentials, using the
# listed methods and any request header.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended for production.
allowed_origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods="GET POST PUT DELETE".split(),
    allow_credentials=True,
    allow_origins=allowed_origins,
)

## Initialize the pipeline
# Working directories used by the endpoints: uploads go in input_images_dir,
# intermediate segmentation output in its "processed/" subdirectory.
input_images_dir = "image/"
temp_processing_dir = f"{input_images_dir}processed/"
# Define a function to handle the POST request at `image-analyzer`
@app.post("/image-analyzer")
async def image_analyzer(image: UploadFile = File(...)):
    """
    Match the surfaces of an uploaded room image against a material catalogue.

    The upload may be raw image bytes or a base64 data URL of the form
    ``data:image/<type>;base64,<payload>``. The image is segmented into
    ``temp_processing_dir`` and the segments are passed to
    ``similarity_inference``; per the original description, the result holds
    the PolyHaven URLs of the top-k materials similar to the wall, ceiling,
    and floor.

    Returns:
        Whatever ``similarity_inference`` returns, serialized by FastAPI.

    Raises:
        HTTPException: status 500 with the error text if any step fails.
    """
    try:
        # Read the upload BEFORE touching the shared working directory. Every
        # step after this await is synchronous, so within this process no other
        # request can interleave between clearing the directory, writing the
        # segments, and reading them back. Do not add awaits (or offload these
        # steps to threads) without switching to per-request directories.
        contents = await image.read()

        # Make sure the working directory exists, then clear earlier results.
        os.makedirs(temp_processing_dir, exist_ok=True)
        for filename in os.listdir(temp_processing_dir):
            file_path = os.path.join(temp_processing_dir, filename)
            try:
                os.remove(file_path)  # Remove the file
            except Exception as e:
                # Best effort: a leftover entry must not abort the request.
                print(f"Failed to delete {file_path}. Reason: {e}")

        # load image: base64 data URLs of any image type, otherwise raw bytes.
        # Decode from the bytes already read instead of the consumed file object.
        if contents.startswith(b"data:image/") and b";base64," in contents:
            image_data = contents.split(b";base64,", 1)[1]
            img = Image.open(io.BytesIO(base64.b64decode(image_data)))
        else:
            img = Image.open(io.BytesIO(contents))
        print("image loaded successfully. Processing image for segmentation and similarity inference...", datetime.now())

        # segment into components
        segmentation_inference(img, temp_processing_dir)
        print("image segmented successfully. Starting similarity inference...", datetime.now())

        # identify similar materials for each component
        matching_textures = similarity_inference(temp_processing_dir)
        print("done", datetime.now())
        return matching_textures
    except Exception as e:
        print(str(e))
        raise HTTPException(status_code=500, detail=str(e)) from e
# Client for the remote "StableDesign" Gradio Space used by /image-render.
# Connecting happens over the network at import time; a failure there must not
# take down the whole API (including /image-analyzer), so it leaves `client`
# as None and the endpoint retries the connection on demand.
try:
    client = Client("MykolaL/StableDesign")
except Exception as e:
    print(f"Could not connect to MykolaL/StableDesign at startup: {e}")
    client = None


@app.post("/image-render")
async def image_render(prompt: str, image: UploadFile = File(...)):
    """
    Render an interior-design variant of an uploaded image with "StableDesign".

    Args:
        prompt: Text prompt forwarded to the remote model.
        image: Raw image bytes or a base64 data URL
            (``data:image/<type>;base64,<payload>``). It is converted to
            grayscale before being sent to the model.

    Returns:
        JSONResponse ``{"image": <base64 of the rendered file>}``, status 200.

    Raises:
        HTTPException: 404 if the model result is not an existing file,
            500 for any other failure.
    """
    # Local imports keep this endpoint self-contained.
    import uuid

    from fastapi.concurrency import run_in_threadpool

    global client
    image_path = None
    try:
        print(f"recieved prompt: {prompt} and image: {image}")
        contents = await image.read()
        # Decode from the bytes already read; accept data URLs of any image type.
        if contents.startswith(b"data:image/") and b";base64," in contents:
            image_data = contents.split(b";base64,", 1)[1]
            img = Image.open(io.BytesIO(base64.b64decode(image_data)))
        else:
            img = Image.open(io.BytesIO(contents))

        # The client-supplied filename is untrusted: keep only its last path
        # component so it cannot escape input_images_dir, and add a random
        # suffix so simultaneous uploads never share a path.
        base_name = os.path.basename(image.filename or "") or "upload"
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        os.makedirs(input_images_dir, exist_ok=True)
        image_path = os.path.join(
            input_images_dir, f"{base_name}{timestamp}_{uuid.uuid4().hex}.png"
        )
        # Convert image to grayscale and save it so it can be uploaded.
        img.convert('L').save(image_path)

        if client is None:
            # Startup connection failed earlier; retry now (blocking I/O).
            client = await run_in_threadpool(Client, "MykolaL/StableDesign")

        # predict() is a long blocking network call; run it in a worker thread
        # so the event loop keeps serving other requests meanwhile.
        result = await run_in_threadpool(
            client.predict,
            image=file(image_path),
            text=prompt,
            num_steps=50,
            guidance_scale=10,
            seed=1111664444,
            strength=1,
            a_prompt="interior design, 4K, high resolution, photorealistic",
            n_prompt="window, door, low resolution, banner, logo, watermark, text, deformed, blurry, out of focus, surreal, ugly, beginner",
            img_size=768,
            api_name="/on_submit",
        )
        new_image_path = result
        if not new_image_path or not os.path.exists(new_image_path):
            raise HTTPException(status_code=404, detail="Image not found")

        # Open the image file and convert it to base64; delete it either way.
        try:
            with open(new_image_path, "rb") as img_file:
                base64_str = base64.b64encode(img_file.read()).decode('utf-8')
        finally:
            os.remove(new_image_path)
        return JSONResponse(content={"image": base64_str}, status_code=200)
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 404 above) pass through unchanged
        # instead of being rewritten as 500.
        raise
    except Exception as e:
        print(str(e))
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Always remove the temporary upload, even when prediction failed.
        if image_path is not None:
            try:
                os.remove(image_path)
            except FileNotFoundError:
                pass
|