robin-courant commited on
Commit
0d05803
1 Parent(s): 45aa913

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -17
app.py CHANGED
@@ -88,7 +88,6 @@ def get_normals(vertices: torch.Tensor, faces: torch.Tensor) -> torch.Tensor:
88
 
89
  return normals
90
 
91
-
92
  @spaces.GPU
93
  def generate(
94
  prompt: str,
@@ -203,22 +202,18 @@ def launch_app(gen_fn: Callable):
203
  demo.queue().launch(share=False)
204
 
205
 
206
- def main():
207
- # Initialize the models and dataset
208
- diffuser, clip_model, dataset, device = init("config")
209
- generate_sample = partial(
210
- generate,
211
- dataset=dataset,
212
- device=device,
213
- diffuser=diffuser,
214
- clip_model=clip_model,
215
- )
216
-
217
- launch_app(generate_sample)
218
-
219
-
220
  # ------------------------------------------------------------------------------------- #
221
 
 
 
 
 
 
 
 
 
 
 
 
222
 
223
- if __name__ == "__main__":
224
- main()
 
88
 
89
  return normals
90
 
 
91
  @spaces.GPU
92
  def generate(
93
  prompt: str,
 
202
  demo.queue().launch(share=False)
203
 
204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  # ------------------------------------------------------------------------------------- #
206
 
207
+ diffuser, clip_model, dataset, device = init("config")
208
+ diffuser.to(device)
209
+ clip_model.to(device)
210
+
211
+ generate_sample = partial(
212
+ generate,
213
+ dataset=dataset,
214
+ device=device,
215
+ diffuser=diffuser,
216
+ clip_model=clip_model,
217
+ )
218
 
219
+ launch_app(generate_sample)