will33am commited on
Commit
3ea0526
1 Parent(s): 909e4ad
Files changed (2) hide show
  1. .ipynb_checkpoints/app-checkpoint.py +2 -2
  2. app.py +2 -2
.ipynb_checkpoints/app-checkpoint.py CHANGED
@@ -23,7 +23,7 @@ ARCH = 'crossvit_18_dagger_408'
23
  ARCH_PATH = './checkpoints/robust_crossvit_18_dagger_408.pt'
24
  CUSTOM_TRANSFORMS = transforms.Compose([transforms.Resize([IMG_MAX_SIZE,IMG_MAX_SIZE]),
25
  transforms.ToTensor()])
26
- DEVICE = 'cpu'
27
 
28
 
29
  def load_model(robust = True):
@@ -87,6 +87,6 @@ if __name__ == '__main__':
87
  calculate_button.click(fn = gradio_fn,
88
  inputs = [image_input,radio_steps,radio_class,radio_robust],
89
  outputs = target_image)
90
- demo.launch(share = True,debug = True)
91
 
92
 
 
23
  ARCH_PATH = './checkpoints/robust_crossvit_18_dagger_408.pt'
24
  CUSTOM_TRANSFORMS = transforms.Compose([transforms.Resize([IMG_MAX_SIZE,IMG_MAX_SIZE]),
25
  transforms.ToTensor()])
26
+ DEVICE = 'cuda'
27
 
28
 
29
  def load_model(robust = True):
 
87
  calculate_button.click(fn = gradio_fn,
88
  inputs = [image_input,radio_steps,radio_class,radio_robust],
89
  outputs = target_image)
90
+ demo.launch(debug = True)
91
 
92
 
app.py CHANGED
@@ -23,7 +23,7 @@ ARCH = 'crossvit_18_dagger_408'
23
  ARCH_PATH = './checkpoints/robust_crossvit_18_dagger_408.pt'
24
  CUSTOM_TRANSFORMS = transforms.Compose([transforms.Resize([IMG_MAX_SIZE,IMG_MAX_SIZE]),
25
  transforms.ToTensor()])
26
- DEVICE = 'cpu'
27
 
28
 
29
  def load_model(robust = True):
@@ -87,6 +87,6 @@ if __name__ == '__main__':
87
  calculate_button.click(fn = gradio_fn,
88
  inputs = [image_input,radio_steps,radio_class,radio_robust],
89
  outputs = target_image)
90
- demo.launch(share = True,debug = True)
91
 
92
 
 
23
  ARCH_PATH = './checkpoints/robust_crossvit_18_dagger_408.pt'
24
  CUSTOM_TRANSFORMS = transforms.Compose([transforms.Resize([IMG_MAX_SIZE,IMG_MAX_SIZE]),
25
  transforms.ToTensor()])
26
+ DEVICE = 'cuda'
27
 
28
 
29
  def load_model(robust = True):
 
87
  calculate_button.click(fn = gradio_fn,
88
  inputs = [image_input,radio_steps,radio_class,radio_robust],
89
  outputs = target_image)
90
+ demo.launch(debug = True)
91
 
92