cocktailpeanut commited on
Commit
ecf9fe8
1 Parent(s): e96dd77
Files changed (2) hide show
  1. app.py +13 -3
  2. requirements.txt +2 -2
app.py CHANGED
@@ -13,9 +13,12 @@ from PIL import Image
13
 
14
  import sf3d.utils as sf3d_utils
15
  from sf3d.system import SF3D
 
16
 
17
  rembg_session = rembg.new_session()
18
 
 
 
19
  COND_WIDTH = 512
20
  COND_HEIGHT = 512
21
  COND_DISTANCE = 1.6
@@ -34,7 +37,8 @@ model = SF3D.from_pretrained(
34
  config_name="config.yaml",
35
  weight_name="model.safetensors",
36
  )
37
- model.eval().cuda()
 
38
 
39
  example_files = [
40
  os.path.join("demo_files/examples", f) for f in os.listdir("demo_files/examples")
@@ -44,9 +48,15 @@ example_files = [
44
  def run_model(input_image):
45
  start = time.time()
46
  with torch.no_grad():
47
- with torch.autocast(device_type="cuda", dtype=torch.float16):
 
 
 
 
 
 
48
  model_batch = create_batch(input_image)
49
- model_batch = {k: v.cuda() for k, v in model_batch.items()}
50
  trimesh_mesh, _glob_dict = model.generate_mesh(model_batch, 1024)
51
  trimesh_mesh = trimesh_mesh[0]
52
 
 
13
 
14
  import sf3d.utils as sf3d_utils
15
  from sf3d.system import SF3D
16
+ import devicetorch
17
 
18
  rembg_session = rembg.new_session()
19
 
20
+ DEVICE = devicetorch.get(torch)
21
+
22
  COND_WIDTH = 512
23
  COND_HEIGHT = 512
24
  COND_DISTANCE = 1.6
 
37
  config_name="config.yaml",
38
  weight_name="model.safetensors",
39
  )
40
+ #model.eval().cuda()
41
+ model.eval().to(device)
42
 
43
  example_files = [
44
  os.path.join("demo_files/examples", f) for f in os.listdir("demo_files/examples")
 
48
  def run_model(input_image):
49
  start = time.time()
50
  with torch.no_grad():
51
+ if DEVICE == "cuda":
52
+ with torch.autocast(device_type="cuda", dtype=torch.float16):
53
+ model_batch = create_batch(input_image)
54
+ model_batch = {k: v.cuda() for k, v in model_batch.items()}
55
+ trimesh_mesh, _glob_dict = model.generate_mesh(model_batch, 1024)
56
+ trimesh_mesh = trimesh_mesh[0]
57
+ else:
58
  model_batch = create_batch(input_image)
59
+ model_batch = {k: v.to(DEVICE) for k, v in model_batch.items()}
60
  trimesh_mesh, _glob_dict = model.generate_mesh(model_batch, 1024)
61
  trimesh_mesh = trimesh_mesh[0]
62
 
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
- torch==2.1.2
2
- torchvision==0.16.2
3
  einops==0.7.0
4
  jaxtyping==0.2.31
5
  omegaconf==2.3.0
 
1
+ #torch==2.1.2
2
+ #torchvision==0.16.2
3
  einops==0.7.0
4
  jaxtyping==0.2.31
5
  omegaconf==2.3.0