JeffreyXiang
commited on
Commit
•
690b53e
1
Parent(s):
bd46f72
Speed UP!
Browse files
app.py
CHANGED
@@ -3,6 +3,7 @@ import spaces
|
|
3 |
from gradio_litmodel3d import LitModel3D
|
4 |
|
5 |
import os
|
|
|
6 |
from typing import *
|
7 |
import torch
|
8 |
import numpy as np
|
@@ -131,7 +132,7 @@ def extract_glb(state: dict, mesh_simplify: float, texture_size: int) -> Tuple[s
|
|
131 |
str: The path to the extracted GLB file.
|
132 |
"""
|
133 |
gs, mesh, model_id = unpack_state(state)
|
134 |
-
glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size)
|
135 |
glb_path = f"/tmp/Trellis-demo/{model_id}.glb"
|
136 |
glb.export(glb_path)
|
137 |
return glb_path, glb_path
|
@@ -161,12 +162,12 @@ with gr.Blocks() as demo:
|
|
161 |
randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
|
162 |
gr.Markdown("Stage 1: Sparse Structure Generation")
|
163 |
with gr.Row():
|
164 |
-
ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=5
|
165 |
-
ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=
|
166 |
gr.Markdown("Stage 2: Structured Latent Generation")
|
167 |
with gr.Row():
|
168 |
-
slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=
|
169 |
-
slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=
|
170 |
|
171 |
generate_btn = gr.Button("Generate")
|
172 |
|
|
|
3 |
from gradio_litmodel3d import LitModel3D
|
4 |
|
5 |
import os
|
6 |
+
os.environ['SPCONV_ALGO'] = 'native'
|
7 |
from typing import *
|
8 |
import torch
|
9 |
import numpy as np
|
|
|
132 |
str: The path to the extracted GLB file.
|
133 |
"""
|
134 |
gs, mesh, model_id = unpack_state(state)
|
135 |
+
glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
|
136 |
glb_path = f"/tmp/Trellis-demo/{model_id}.glb"
|
137 |
glb.export(glb_path)
|
138 |
return glb_path, glb_path
|
|
|
162 |
randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
|
163 |
gr.Markdown("Stage 1: Sparse Structure Generation")
|
164 |
with gr.Row():
|
165 |
+
ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
|
166 |
+
ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
|
167 |
gr.Markdown("Stage 2: Structured Latent Generation")
|
168 |
with gr.Row():
|
169 |
+
slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
|
170 |
+
slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
|
171 |
|
172 |
generate_btn = gr.Button("Generate")
|
173 |
|
trellis/modules/sparse/__init__.py
CHANGED
@@ -24,6 +24,8 @@ def __from_env():
|
|
24 |
if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']:
|
25 |
ATTN = env_sparse_attn
|
26 |
|
|
|
|
|
27 |
|
28 |
__from_env()
|
29 |
|
|
|
24 |
if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']:
|
25 |
ATTN = env_sparse_attn
|
26 |
|
27 |
+
print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}")
|
28 |
+
|
29 |
|
30 |
__from_env()
|
31 |
|
trellis/modules/sparse/conv/__init__.py
CHANGED
@@ -1,6 +1,21 @@
|
|
1 |
from .. import BACKEND
|
2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
if BACKEND == 'torchsparse':
|
4 |
from .conv_torchsparse import *
|
5 |
elif BACKEND == 'spconv':
|
6 |
-
from .conv_spconv import *
|
|
|
1 |
from .. import BACKEND
|
2 |
|
3 |
+
|
4 |
+
SPCONV_ALGO = 'auto' # 'auto', 'implicit_gemm', 'native'
|
5 |
+
|
6 |
+
def __from_env():
|
7 |
+
import os
|
8 |
+
|
9 |
+
global SPCONV_ALGO
|
10 |
+
env_spconv_algo = os.environ.get('SPCONV_ALGO')
|
11 |
+
if env_spconv_algo is not None and env_spconv_algo in ['auto', 'implicit_gemm', 'native']:
|
12 |
+
SPCONV_ALGO = env_spconv_algo
|
13 |
+
print(f"[SPARSE][CONV] spconv algo: {SPCONV_ALGO}")
|
14 |
+
|
15 |
+
|
16 |
+
__from_env()
|
17 |
+
|
18 |
if BACKEND == 'torchsparse':
|
19 |
from .conv_torchsparse import *
|
20 |
elif BACKEND == 'spconv':
|
21 |
+
from .conv_spconv import *
|
trellis/modules/sparse/conv/conv_spconv.py
CHANGED
@@ -2,16 +2,22 @@ import torch
|
|
2 |
import torch.nn as nn
|
3 |
from .. import SparseTensor
|
4 |
from .. import DEBUG
|
|
|
5 |
|
6 |
class SparseConv3d(nn.Module):
|
7 |
def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
|
8 |
super(SparseConv3d, self).__init__()
|
9 |
if 'spconv' not in globals():
|
10 |
import spconv.pytorch as spconv
|
|
|
|
|
|
|
|
|
|
|
11 |
if stride == 1 and (padding is None):
|
12 |
-
self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, dilation=dilation, bias=bias, indice_key=indice_key)
|
13 |
else:
|
14 |
-
self.conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, indice_key=indice_key)
|
15 |
self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)
|
16 |
self.padding = padding
|
17 |
|
|
|
2 |
import torch.nn as nn
|
3 |
from .. import SparseTensor
|
4 |
from .. import DEBUG
|
5 |
+
from . import SPCONV_ALGO
|
6 |
|
7 |
class SparseConv3d(nn.Module):
|
8 |
def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
|
9 |
super(SparseConv3d, self).__init__()
|
10 |
if 'spconv' not in globals():
|
11 |
import spconv.pytorch as spconv
|
12 |
+
algo = None
|
13 |
+
if SPCONV_ALGO == 'native':
|
14 |
+
algo = spconv.ConvAlgo.Native
|
15 |
+
elif SPCONV_ALGO == 'implicit_gemm':
|
16 |
+
algo = spconv.ConvAlgo.MaskImplicitGemm
|
17 |
if stride == 1 and (padding is None):
|
18 |
+
self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, dilation=dilation, bias=bias, indice_key=indice_key, algo=algo)
|
19 |
else:
|
20 |
+
self.conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, indice_key=indice_key, algo=algo)
|
21 |
self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)
|
22 |
self.padding = padding
|
23 |
|
trellis/utils/postprocessing_utils.py
CHANGED
@@ -448,7 +448,7 @@ def to_glb(
|
|
448 |
observations, masks, extrinsics, intrinsics,
|
449 |
texture_size=texture_size, mode='opt',
|
450 |
lambda_tv=0.01,
|
451 |
-
verbose=
|
452 |
)
|
453 |
texture = Image.fromarray(texture)
|
454 |
|
|
|
448 |
observations, masks, extrinsics, intrinsics,
|
449 |
texture_size=texture_size, mode='opt',
|
450 |
lambda_tv=0.01,
|
451 |
+
verbose=verbose
|
452 |
)
|
453 |
texture = Image.fromarray(texture)
|
454 |
|