nikigoli commited on
Commit
20bd81e
1 Parent(s): 3c12f0e

Removed GPU for building models and transforms

Browse files
Files changed (1) hide show
  1. app.py +4 -5
app.py CHANGED
@@ -116,7 +116,6 @@ def get_device():
116
  return torch.device('cpu')
117
 
118
  # Get counting model.
119
- @spaces.GPU
120
  def build_model_and_transforms(args):
121
  normalize = T.Compose(
122
  [T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]
@@ -137,7 +136,8 @@ def build_model_and_transforms(args):
137
  else:
138
  raise ValueError("Key {} can used by args only".format(k))
139
 
140
- device = torch.device(args.device)
 
141
  # fix the seed for reproducibility
142
  seed = 42
143
  torch.manual_seed(seed)
@@ -156,7 +156,6 @@ def build_model_and_transforms(args):
156
  model.load_state_dict(checkpoint, strict=False)
157
 
158
  model.eval()
159
- model.cpu()
160
 
161
  return model, data_transform
162
 
@@ -164,7 +163,7 @@ def build_model_and_transforms(args):
164
  parser = argparse.ArgumentParser("Counting Application", parents=[get_args_parser()])
165
  args = parser.parse_args()
166
 
167
- args.device = get_device()
168
  model, transform = build_model_and_transforms(args)
169
 
170
  examples = [
@@ -406,7 +405,7 @@ As shown earlier, there are 3 ways to specify the object to count: (1) with text
406
 
407
  with gr.Blocks(title="CountGD: Multi-Modal Open-World Counting", theme="soft", head="""<meta name="viewport" content="width=device-width, initial-scale=1, user-scalable=1">""") as demo:
408
  state = gr.State(value=[AppSteps.JUST_TEXT])
409
- device = gr.State(args.device)
410
  with gr.Tab("Tutorial"):
411
  with gr.Row():
412
  with gr.Column():
 
116
  return torch.device('cpu')
117
 
118
  # Get counting model.
 
119
  def build_model_and_transforms(args):
120
  normalize = T.Compose(
121
  [T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]
 
136
  else:
137
  raise ValueError("Key {} can used by args only".format(k))
138
 
139
+ # Start with model on CPU.
140
+ args.device = "cpu"
141
  # fix the seed for reproducibility
142
  seed = 42
143
  torch.manual_seed(seed)
 
156
  model.load_state_dict(checkpoint, strict=False)
157
 
158
  model.eval()
 
159
 
160
  return model, data_transform
161
 
 
163
  parser = argparse.ArgumentParser("Counting Application", parents=[get_args_parser()])
164
  args = parser.parse_args()
165
 
166
+ device = get_device()
167
  model, transform = build_model_and_transforms(args)
168
 
169
  examples = [
 
405
 
406
  with gr.Blocks(title="CountGD: Multi-Modal Open-World Counting", theme="soft", head="""<meta name="viewport" content="width=device-width, initial-scale=1, user-scalable=1">""") as demo:
407
  state = gr.State(value=[AppSteps.JUST_TEXT])
408
+ device = gr.State(device)
409
  with gr.Tab("Tutorial"):
410
  with gr.Row():
411
  with gr.Column():