z-uo commited on
Commit
bd930f4
1 Parent(s): fe298bb

remove file after download button creation and scale image input

Browse files
Files changed (3) hide show
  1. app.py +14 -7
  2. model.obj +0 -0
  3. tmp.png +0 -0
app.py CHANGED
@@ -4,15 +4,14 @@ import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
  from torchvision import models
7
- from torchvision.transforms import ToTensor
8
 
9
  import numpy as np
10
  from PIL import Image
11
  import math
12
  from obj2html import obj2html
13
 
14
- from io import BytesIO
15
- import base64
16
 
17
  # DEPTH IMAGE TO OBJ
18
  minDepth=10
@@ -153,7 +152,13 @@ model.load_state_dict(torch.hub.load_state_dict_from_url(path, progress=True))
153
  model.eval()
154
 
155
  def predict(inp):
156
- torch_image = ToTensor()(inp)
 
 
 
 
 
 
157
  images = torch_image.unsqueeze(0)
158
 
159
  with torch.no_grad():
@@ -166,7 +171,7 @@ def predict(inp):
166
  create_obj(depth, 'model.obj')
167
  html_string = obj2html('model.obj', html_elements_only=True)
168
 
169
- return img, html_string
170
 
171
 
172
  # STREAMLIT
@@ -175,19 +180,21 @@ uploader = st.file_uploader('Upload your portrait here',type=['jpg','jpeg','png'
175
 
176
  if uploader is not None:
177
  pil_image = Image.open(uploader)
178
- pil_depth, html_string = predict(pil_image)
179
 
180
  components.html(html_string)
181
  #st.markdown(html_string, unsafe_allow_html=True)
182
 
183
  col1, col2, col3 = st.columns(3)
184
  with col1:
185
- st.image(pil_image)
186
  with col2:
187
  st.image(pil_depth)
188
  with col3:
189
  with open('model.obj') as f:
190
  st.download_button('Download model.obj', f, file_name="model.obj")
 
191
  pil_depth.save('tmp.png')
192
  with open('tmp.png', "rb") as f:
193
  st.download_button('Download depth.png', f,file_name="depth.png", mime="image/png")
 
 
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
  from torchvision import models
7
+ from torchvision.transforms import ToTensor, Resize
8
 
9
  import numpy as np
10
  from PIL import Image
11
  import math
12
  from obj2html import obj2html
13
 
14
+ import os
 
15
 
16
  # DEPTH IMAGE TO OBJ
17
  minDepth=10
 
152
  model.eval()
153
 
154
  def predict(inp):
155
+ width, height = inp.size
156
+ if width > height:
157
+ scale_fn = Resize((640, int((640*width)/height)))
158
+ else:
159
+ scale_fn = Resize((int((640*height)/width), 640))
160
+ res_img = scale_fn(inp)
161
+ torch_image = ToTensor()(res_img)
162
  images = torch_image.unsqueeze(0)
163
 
164
  with torch.no_grad():
 
171
  create_obj(depth, 'model.obj')
172
  html_string = obj2html('model.obj', html_elements_only=True)
173
 
174
+ return res_img, img, html_string
175
 
176
 
177
  # STREAMLIT
 
180
 
181
  if uploader is not None:
182
  pil_image = Image.open(uploader)
183
+ pil_scaled, pil_depth, html_string = predict(pil_image)
184
 
185
  components.html(html_string)
186
  #st.markdown(html_string, unsafe_allow_html=True)
187
 
188
  col1, col2, col3 = st.columns(3)
189
  with col1:
190
+ st.image(pil_scaled)
191
  with col2:
192
  st.image(pil_depth)
193
  with col3:
194
  with open('model.obj') as f:
195
  st.download_button('Download model.obj', f, file_name="model.obj")
196
+ os.remove('model.obj')
197
  pil_depth.save('tmp.png')
198
  with open('tmp.png', "rb") as f:
199
  st.download_button('Download depth.png', f,file_name="depth.png", mime="image/png")
200
+ os.remove('tmp.png')
model.obj DELETED
The diff for this file is too large to render. See raw diff
 
tmp.png DELETED
Binary file (9.25 kB)