canadianjosieharrison commited on
Commit
64c4ea7
·
verified ·
1 Parent(s): b93a8f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -94
app.py CHANGED
@@ -1,94 +1,94 @@
1
- import base64
2
- import io
3
- from fastapi import FastAPI, UploadFile, File, HTTPException
4
- import os
5
- import shutil
6
- from PIL import Image
7
- from fastapi.responses import JSONResponse
8
- from semantic_seg_model import segmentation_inference
9
- from similarity_inference import similarity_inference
10
- import json
11
- from gradio_client import Client, file
12
-
13
- app = FastAPI()
14
-
15
- ## Initialize the pipeline
16
- input_images_dir = "image/"
17
- temp_processing_dir = input_images_dir + "processed/"
18
-
19
- # Define a function to handle the POST request at `imageAnalyzer`
20
- @app.post("/imageAnalyzer")
21
- def imageAnalyzer(image: UploadFile = File(...)):
22
- """
23
- This function takes in an image filepath and will return the PolyHaven url addresses of the
24
- top k materials similar to the wall, ceiling, and floor.
25
- """
26
- try:
27
- # load image
28
- image_path = os.path.join(input_images_dir, image.filename)
29
- with open(image_path, "wb") as buffer:
30
- shutil.copyfileobj(image.file, buffer)
31
- image = Image.open(image_path)
32
-
33
- # segment into components
34
- segmentation_inference(image, temp_processing_dir)
35
-
36
- # identify similar materials for each component
37
- matching_urls = similarity_inference(temp_processing_dir)
38
- print(matching_urls)
39
-
40
- # Return the urls in a JSON response
41
- return matching_urls
42
-
43
- except Exception as e:
44
- print(str(e))
45
- raise HTTPException(status_code=500, detail=str(e))
46
-
47
-
48
- client = Client("MykolaL/StableDesign")
49
-
50
- @app.post("/image-render")
51
- def imageRender(prompt: str, image: UploadFile = File(...)):
52
- """
53
- Makes a prediction using the "StableDesign" model hosted on a server.
54
-
55
- Returns:
56
- The prediction result.
57
- """
58
- try:
59
- image_path = os.path.join(input_images_dir, image.filename)
60
- with open(image_path, "wb") as buffer:
61
- shutil.copyfileobj(image.file, buffer)
62
- image = Image.open(image_path)
63
- # Convert PIL image to the required format for the prediction model, if necessary
64
- # This example assumes the model accepts PIL images directly
65
-
66
- result = client.predict(
67
- image=file(image_path),
68
- text=prompt,
69
- num_steps=50,
70
- guidance_scale=10,
71
- seed=1111664444,
72
- strength=0.9,
73
- a_prompt="interior design, 4K, high resolution, photorealistic",
74
- n_prompt="window, door, low resolution, banner, logo, watermark, text, deformed, blurry, out of focus, surreal, ugly, beginner",
75
- img_size=768,
76
- api_name="/on_submit"
77
- )
78
- image_path = result
79
- if not os.path.exists(image_path):
80
- raise HTTPException(status_code=404, detail="Image not found")
81
-
82
- # Open the image file and convert it to base64
83
- with open(image_path, "rb") as img_file:
84
- base64_str = base64.b64encode(img_file.read()).decode('utf-8')
85
-
86
- return JSONResponse(content={"image": base64_str}, status_code=200)
87
- except Exception as e:
88
- print(str(e))
89
- raise HTTPException(status_code=500, detail=str(e))
90
-
91
-
92
- @app.get("/")
93
- def test():
94
- return {"Hello": "World"}
 
1
+ import base64
2
+ import io
3
+ from fastapi import FastAPI, UploadFile, File, HTTPException
4
+ import os
5
+ import shutil
6
+ from PIL import Image
7
+ from fastapi.responses import JSONResponse
8
+ from semantic_seg_model import segmentation_inference
9
+ from similarity_inference import similarity_inference
10
+ import json
11
+ from gradio_client import Client, file
12
+
13
+ app = FastAPI()
14
+
15
+ ## Initialize the pipeline
16
+ input_images_dir = "image/"
17
+ temp_processing_dir = input_images_dir + "processed/"
18
+
19
+ # Define a function to handle the POST request at `imageAnalyzer`
20
+ @app.post("/imageAnalyzer")
21
+ def imageAnalyzer(image: UploadFile = File(...)):
22
+ """
23
+ This function takes in an image filepath and will return the PolyHaven url addresses of the
24
+ top k materials similar to the wall, ceiling, and floor.
25
+ """
26
+ try:
27
+ # load image
28
+ image_path = os.path.join(input_images_dir, image.filename)
29
+ with open(image_path, "wb") as buffer:
30
+ shutil.copyfileobj(image.file, buffer)
31
+ image = Image.open(image_path)
32
+
33
+ # segment into components
34
+ segmentation_inference(image, temp_processing_dir)
35
+
36
+ # identify similar materials for each component
37
+ matching_urls = similarity_inference(temp_processing_dir)
38
+ print(matching_urls)
39
+
40
+ # Return the urls in a JSON response
41
+ return matching_urls
42
+
43
+ except Exception as e:
44
+ print(str(e))
45
+ raise HTTPException(status_code=500, detail=str(e))
46
+
47
+
48
+ client = Client("MykolaL/StableDesign")
49
+
50
+ @app.post("/image-render")
51
+ def imageRender(prompt: str, image: UploadFile = File(...)):
52
+ """
53
+ Makes a prediction using the "StableDesign" model hosted on a server.
54
+
55
+ Returns:
56
+ The prediction result.
57
+ """
58
+ try:
59
+ image_path = os.path.join(input_images_dir, image.filename)
60
+ with open(image_path, "wb") as buffer:
61
+ shutil.copyfileobj(image.file, buffer)
62
+ image = Image.open(image_path)
63
+ # Convert PIL image to the required format for the prediction model, if necessary
64
+ # This example assumes the model accepts PIL images directly
65
+
66
+ result = client.predict(
67
+ image=file(image_path),
68
+ text=prompt,
69
+ num_steps=50,
70
+ guidance_scale=10,
71
+ seed=1111664444,
72
+ strength=0.9,
73
+ a_prompt="interior design, 4K, high resolution, photorealistic",
74
+ n_prompt="window, door, low resolution, banner, logo, watermark, text, deformed, blurry, out of focus, surreal, ugly, beginner",
75
+ img_size=768,
76
+ api_name="/on_submit"
77
+ )
78
+ image_path = result
79
+ if not os.path.exists(image_path):
80
+ raise HTTPException(status_code=404, detail="Image not found")
81
+
82
+ # Open the image file and convert it to base64
83
+ with open(image_path, "rb") as img_file:
84
+ base64_str = base64.b64encode(img_file.read()).decode('utf-8')
85
+
86
+ return JSONResponse(content={"image": base64_str}, status_code=200)
87
+ except Exception as e:
88
+ print(str(e))
89
+ raise HTTPException(status_code=500, detail=str(e))
90
+
91
+
92
+ # @app.get("/")
93
+ # def test():
94
+ # return {"Hello": "World"}