canadianjosieharrison commited on
Commit
e5685f4
·
verified ·
1 Parent(s): b0ec6f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -28
app.py CHANGED
@@ -7,38 +7,49 @@ from PIL import Image
7
  from fastapi.responses import JSONResponse
8
  from semantic_seg_model import segmentation_inference
9
  from similarity_inference import similarity_inference
10
- import json
11
  from gradio_client import Client, file
 
 
 
12
 
13
  app = FastAPI(docs_url="/")
14
 
 
 
 
 
 
 
 
 
 
15
  ## Initialize the pipeline
16
  input_images_dir = "image/"
17
  temp_processing_dir = input_images_dir + "processed/"
18
 
19
- # Define a function to handle the POST request at `imageAnalyzer`
20
- @app.post("/imageAnalyzer")
21
- def imageAnalyzer(image: UploadFile = File(...)):
22
  """
23
  This function takes in an image filepath and will return the PolyHaven url addresses of the
24
  top k materials similar to the wall, ceiling, and floor.
25
  """
26
  try:
27
  # load image
28
- image_path = os.path.join(input_images_dir, image.filename)
29
  with open(image_path, "wb") as buffer:
30
  shutil.copyfileobj(image.file, buffer)
31
  image = Image.open(image_path)
32
-
33
  # segment into components
34
  segmentation_inference(image, temp_processing_dir)
35
-
36
  # identify similar materials for each component
37
- matching_urls = similarity_inference(temp_processing_dir)
38
- print(matching_urls)
39
 
40
  # Return the urls in a JSON response
41
- return matching_urls
42
 
43
  except Exception as e:
44
  print(str(e))
@@ -48,7 +59,7 @@ def imageAnalyzer(image: UploadFile = File(...)):
48
  client = Client("MykolaL/StableDesign")
49
 
50
  @app.post("/image-render")
51
- def imageRender(prompt: str, image: UploadFile = File(...)):
52
  """
53
  Makes a prediction using the "StableDesign" model hosted on a server.
54
 
@@ -56,39 +67,41 @@ def imageRender(prompt: str, image: UploadFile = File(...)):
56
  The prediction result.
57
  """
58
  try:
59
- image_path = os.path.join(input_images_dir, image.filename)
60
- with open(image_path, "wb") as buffer:
61
- shutil.copyfileobj(image.file, buffer)
62
- image = Image.open(image_path)
63
- # Convert PIL image to the required format for the prediction model, if necessary
64
- # This example assumes the model accepts PIL images directly
65
-
 
 
 
 
 
66
  result = client.predict(
67
  image=file(image_path),
68
  text=prompt,
69
  num_steps=50,
70
  guidance_scale=10,
71
  seed=1111664444,
72
- strength=0.9,
73
  a_prompt="interior design, 4K, high resolution, photorealistic",
74
  n_prompt="window, door, low resolution, banner, logo, watermark, text, deformed, blurry, out of focus, surreal, ugly, beginner",
75
  img_size=768,
76
  api_name="/on_submit"
77
  )
78
- image_path = result
79
- if not os.path.exists(image_path):
 
80
  raise HTTPException(status_code=404, detail="Image not found")
81
 
82
  # Open the image file and convert it to base64
83
- with open(image_path, "rb") as img_file:
84
  base64_str = base64.b64encode(img_file.read()).decode('utf-8')
85
-
86
  return JSONResponse(content={"image": base64_str}, status_code=200)
87
  except Exception as e:
88
  print(str(e))
89
  raise HTTPException(status_code=500, detail=str(e))
90
-
91
-
92
- # @app.get("/")
93
- # def test():
94
- # return {"Hello": "World"}
 
7
  from fastapi.responses import JSONResponse
8
  from semantic_seg_model import segmentation_inference
9
  from similarity_inference import similarity_inference
 
10
  from gradio_client import Client, file
11
+ from datetime import datetime
12
+ from fastapi.middleware.cors import CORSMiddleware
13
+
14
 
15
  app = FastAPI(docs_url="/")
16
 
17
+ allowed_origins = ["*"]
18
+ app.add_middleware(
19
+ CORSMiddleware,
20
+ allow_origins=allowed_origins,
21
+ allow_credentials=True,
22
+ allow_methods=["GET", "POST", "PUT", "DELETE"],
23
+ allow_headers=["*"],
24
+ )
25
+
26
  ## Initialize the pipeline
27
  input_images_dir = "image/"
28
  temp_processing_dir = input_images_dir + "processed/"
29
 
30
+ # Define a function to handle the POST request at `image-analyzer`
31
+ @app.post("/image-analyzer")
32
+ def image_analyzer(image: UploadFile = File(...)):
33
  """
34
  This function takes in an image filepath and will return the PolyHaven url addresses of the
35
  top k materials similar to the wall, ceiling, and floor.
36
  """
37
  try:
38
  # load image
39
+ image_path = os.path.join(input_images_dir, "image.png")
40
  with open(image_path, "wb") as buffer:
41
  shutil.copyfileobj(image.file, buffer)
42
  image = Image.open(image_path)
43
+ print("image loaded successfully. Processing image for segmentation and similarity inference...", datetime.now())
44
  # segment into components
45
  segmentation_inference(image, temp_processing_dir)
46
+ print("image segmented successfully. Starting similarity inference...", datetime.now())
47
  # identify similar materials for each component
48
+ matching_textures = similarity_inference(temp_processing_dir)
49
+ print("done", datetime.now())
50
 
51
  # Return the urls in a JSON response
52
+ return matching_textures
53
 
54
  except Exception as e:
55
  print(str(e))
 
59
  client = Client("MykolaL/StableDesign")
60
 
61
  @app.post("/image-render")
62
+ async def image_render(prompt: str, image: UploadFile = File(...)):
63
  """
64
  Makes a prediction using the "StableDesign" model hosted on a server.
65
 
 
67
  The prediction result.
68
  """
69
  try:
70
+ print(f"recieved prompt: {prompt} and image: {image}")
71
+ image_path = os.path.join(input_images_dir, image.filename+datetime.now().strftime("%Y-%m-%d_%H-%M-%S")+".png")
72
+ contents = await image.read()
73
+ # Remove the prefix "data:image/png;base64,"
74
+ image_data = contents.split(b";base64,")[1]
75
+ # Decode base64 data
76
+ decoded_image = base64.b64decode(image_data)
77
+ image = Image.open(io.BytesIO(decoded_image))
78
+ # Convert image to grayscale
79
+ grayscale_image = image.convert('L')
80
+ # Save the processed image to the specified path
81
+ grayscale_image.save(image_path)
82
  result = client.predict(
83
  image=file(image_path),
84
  text=prompt,
85
  num_steps=50,
86
  guidance_scale=10,
87
  seed=1111664444,
88
+ strength=1,
89
  a_prompt="interior design, 4K, high resolution, photorealistic",
90
  n_prompt="window, door, low resolution, banner, logo, watermark, text, deformed, blurry, out of focus, surreal, ugly, beginner",
91
  img_size=768,
92
  api_name="/on_submit"
93
  )
94
+
95
+ new_image_path = result
96
+ if not os.path.exists(new_image_path):
97
  raise HTTPException(status_code=404, detail="Image not found")
98
 
99
  # Open the image file and convert it to base64
100
+ with open(new_image_path, "rb") as img_file:
101
  base64_str = base64.b64encode(img_file.read()).decode('utf-8')
102
+
103
  return JSONResponse(content={"image": base64_str}, status_code=200)
104
  except Exception as e:
105
  print(str(e))
106
  raise HTTPException(status_code=500, detail=str(e))
107
+