File size: 4,757 Bytes
64c4ea7
 
 
 
 
 
 
 
 
e5685f4
 
 
64c4ea7
b0ec6f8
64c4ea7
e5685f4
 
 
 
 
 
 
 
 
64c4ea7
 
 
 
e5685f4
 
94c0acf
64c4ea7
 
 
 
 
4148997
 
 
 
 
 
 
94c0acf
64c4ea7
94c0acf
 
 
 
 
 
 
 
 
 
 
e5685f4
94c0acf
64c4ea7
94c0acf
e5685f4
94c0acf
64c4ea7
e5685f4
 
64c4ea7
 
e5685f4
64c4ea7
 
 
 
 
 
 
 
 
e5685f4
64c4ea7
 
 
 
 
 
 
e5685f4
 
 
94c0acf
 
 
 
 
 
 
 
 
e5685f4
94c0acf
e5685f4
 
64c4ea7
 
 
 
 
 
e5685f4
64c4ea7
 
 
 
 
cf80bf3
e5685f4
cf80bf3
 
e5685f4
64c4ea7
 
 
e5685f4
64c4ea7
e5685f4
cf80bf3
64c4ea7
 
 
 
94c0acf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import base64
import io
from fastapi import FastAPI, UploadFile, File, HTTPException
import os
from PIL import Image
from fastapi.responses import JSONResponse
from semantic_seg_model import segmentation_inference
from similarity_inference import similarity_inference
from gradio_client import Client, file
from datetime import datetime
from fastapi.middleware.cors import CORSMiddleware


# FastAPI application; the interactive API docs (Swagger UI) are served at "/".
app = FastAPI(docs_url="/")

# Cross-origin policy for browser clients.
# NOTE(review): a wildcard origin together with allow_credentials=True means
# any website may issue credentialed requests (Starlette echoes the caller's
# Origin in that case). Confirm this is intended before exposing publicly.
allowed_origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["GET", "POST", "PUT", "DELETE"],
    allow_credentials=True,
    allow_origins=allowed_origins,
)

# Working directories, relative to the process's current working directory.
# Uploaded images for rendering are written under input_images_dir; the
# segmentation step writes its per-component outputs into temp_processing_dir.
input_images_dir = "image/"
temp_processing_dir = f"{input_images_dir}processed/"

# Handle POST requests at `/image-analyzer`.
@app.post("/image-analyzer")
async def image_analyzer(image: UploadFile = File(...)):
    """
    Find materials similar to the wall, ceiling, and floor of an uploaded image.

    The upload may be raw image bytes or a base64 data URL
    (``data:image/<type>;base64,<payload>``). The image is segmented into
    component outputs under ``temp_processing_dir`` and each component is
    matched by ``similarity_inference``, which yields the PolyHaven URLs of
    the top-k most similar materials.

    Returns:
        The value produced by ``similarity_inference``, serialized by FastAPI.

    Raises:
        HTTPException: status 500 carrying the error message if any step fails.
    """
    try:
        # Read the upload first: this is the handler's only ``await``. Everything
        # after it (clearing the shared directory, segmentation, matching) runs
        # without yielding to the event loop, so another request cannot slip its
        # segmentation outputs in between our cleanup and our similarity search.
        contents = await image.read()
        if contents.startswith(b"data:image/"):
            # Data URL: drop the "data:image/<type>;base64," header and decode.
            _, _, payload = contents.partition(b";base64,")
            img = Image.open(io.BytesIO(base64.b64decode(payload)))
        else:
            # Raw image bytes: reuse the bytes already read rather than
            # re-reading the upload stream.
            img = Image.open(io.BytesIO(contents))

        print("image loaded successfully. Processing image for segmentation and similarity inference...", datetime.now())

        # Clear outputs left by a previous request. Best effort: a file that
        # cannot be removed is reported and skipped. The directory is created
        # if it does not exist yet, so a fresh checkout does not fail here.
        os.makedirs(temp_processing_dir, exist_ok=True)
        for filename in os.listdir(temp_processing_dir):
            file_path = os.path.join(temp_processing_dir, filename)
            try:
                os.remove(file_path)
            except OSError as e:
                print(f"Failed to delete {file_path}. Reason: {e}")

        # segment into components
        segmentation_inference(img, temp_processing_dir)
        print("image segmented successfully. Starting similarity inference...", datetime.now())

        # identify similar materials for each component
        matching_textures = similarity_inference(temp_processing_dir)
        print("done", datetime.now())

        # Return the urls in a JSON response
        return matching_textures

    except Exception as e:
        print(str(e))
        raise HTTPException(status_code=500, detail=str(e)) from e
    

# Gradio client for the "MykolaL/StableDesign" Space. It is built once, when
# this module is imported, and shared by every /image-render request.
# NOTE(review): construction presumably contacts the Space, so startup may fail
# if it is unreachable — confirm.
client = Client(src="MykolaL/StableDesign")

@app.post("/image-render")
async def image_render(prompt: str, image: UploadFile = File(...)):
    """
    Re-render an uploaded interior image with the remote "StableDesign" model.

    The upload (raw image bytes or a ``data:image/<type>;base64,`` data URL) is
    converted to grayscale, saved as a temporary PNG under ``input_images_dir``
    and sent with ``prompt`` to the ``/on_submit`` endpoint of the module-level
    Gradio ``client``. Both the temporary input and the returned output file
    are deleted before the handler returns, whether it succeeds or fails.

    Returns:
        JSONResponse ``{"image": <base64-encoded bytes of the generated file>}``.

    Raises:
        HTTPException: 404 if the model did not return an existing file path,
            500 for any other failure.
    """
    image_path = None
    new_image_path = None
    try:
        print(f"recieved prompt: {prompt} and image: {image}")
        # The client-supplied filename is untrusted: keep only its final path
        # component so "../" segments or an absolute path cannot escape
        # input_images_dir, and fall back to a fixed stem when none was sent.
        stem = os.path.basename(image.filename or "") or "upload"
        os.makedirs(input_images_dir, exist_ok=True)
        image_path = os.path.join(
            input_images_dir,
            stem + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".png",
        )

        contents = await image.read()
        if contents.startswith(b"data:image/"):
            # Data URL: drop the "data:image/<type>;base64," header and decode.
            _, _, payload = contents.partition(b";base64,")
            img = Image.open(io.BytesIO(base64.b64decode(payload)))
        else:
            # Raw image bytes: reuse the bytes already read.
            img = Image.open(io.BytesIO(contents))

        # The model is fed a grayscale version of the upload.
        img.convert('L').save(image_path)

        result = client.predict(
                image=file(image_path),
                text=prompt,
                num_steps=50,
                guidance_scale=10,
                seed=1111664444,
                strength=1,
                a_prompt="interior design, 4K, high resolution, photorealistic",
                n_prompt="window, door, low resolution, banner, logo, watermark, text, deformed, blurry, out of focus, surreal, ugly, beginner",
                img_size=768,
                api_name="/on_submit"
        )

        # Assumes the endpoint returns a local file path to the generated image.
        new_image_path = result
        if not new_image_path or not os.path.exists(new_image_path):
            raise HTTPException(status_code=404, detail="Image not found")

        # Open the image file and convert it to base64
        with open(new_image_path, "rb") as img_file:
            base64_str = base64.b64encode(img_file.read()).decode('utf-8')

        return JSONResponse(content={"image": base64_str}, status_code=200)
    except HTTPException:
        # Propagate deliberate HTTP errors (e.g. the 404 above) unchanged
        # instead of converting them into a generic 500.
        raise
    except Exception as e:
        print(str(e))
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Delete the temporary input and the model output on every exit path.
        # Cleanup failures are reported but never mask the original outcome.
        for path in (image_path, new_image_path):
            if not isinstance(path, (str, os.PathLike)):
                continue
            try:
                os.remove(path)
            except FileNotFoundError:
                pass
            except OSError as cleanup_error:
                print(f"Failed to delete {path}. Reason: {cleanup_error}")