Spaces:
Running
on
Zero
Running
on
Zero
robin-courant
commited on
Commit
•
85ea87c
1
Parent(s):
c008abb
Update app.py
Browse files
app.py
CHANGED
@@ -88,18 +88,22 @@ def get_normals(vertices: torch.Tensor, faces: torch.Tensor) -> torch.Tensor:
|
|
88 |
|
89 |
return normals
|
90 |
|
|
|
91 |
@spaces.GPU
|
92 |
def generate(
|
93 |
prompt: str,
|
94 |
seed: int,
|
95 |
guidance_weight: float,
|
96 |
sample_label: str,
|
97 |
-
# -----------------------
|
98 |
dataset: MultimodalDataset,
|
99 |
device: torch.device,
|
100 |
diffuser: Diffuser,
|
101 |
clip_model: clip.model.CLIP,
|
102 |
) -> Dict[str, Any]:
|
|
|
|
|
|
|
103 |
# Set arguments
|
104 |
set_random_seed(seed)
|
105 |
diffuser.gen_seeds = np.array([seed])
|
@@ -206,9 +210,6 @@ def launch_app(gen_fn: Callable):
|
|
206 |
# ------------------------------------------------------------------------------------- #
|
207 |
|
208 |
diffuser, clip_model, dataset, device = init("config")
|
209 |
-
diffuser.to("cuda")
|
210 |
-
clip_model.to("cuda")
|
211 |
-
|
212 |
generate_sample = partial(
|
213 |
generate,
|
214 |
dataset=dataset,
|
@@ -216,5 +217,4 @@ generate_sample = partial(
|
|
216 |
diffuser=diffuser,
|
217 |
clip_model=clip_model,
|
218 |
)
|
219 |
-
|
220 |
launch_app(generate_sample)
|
|
|
88 |
|
89 |
return normals
|
90 |
|
91 |
+
|
92 |
@spaces.GPU
|
93 |
def generate(
|
94 |
prompt: str,
|
95 |
seed: int,
|
96 |
guidance_weight: float,
|
97 |
sample_label: str,
|
98 |
+
# ----------------------- #
|
99 |
dataset: MultimodalDataset,
|
100 |
device: torch.device,
|
101 |
diffuser: Diffuser,
|
102 |
clip_model: clip.model.CLIP,
|
103 |
) -> Dict[str, Any]:
|
104 |
+
diffuser.to(device)
|
105 |
+
clip_model.to(device)
|
106 |
+
|
107 |
# Set arguments
|
108 |
set_random_seed(seed)
|
109 |
diffuser.gen_seeds = np.array([seed])
|
|
|
210 |
# ------------------------------------------------------------------------------------- #
|
211 |
|
212 |
diffuser, clip_model, dataset, device = init("config")
|
|
|
|
|
|
|
213 |
generate_sample = partial(
|
214 |
generate,
|
215 |
dataset=dataset,
|
|
|
217 |
diffuser=diffuser,
|
218 |
clip_model=clip_model,
|
219 |
)
|
|
|
220 |
launch_app(generate_sample)
|