Spaces:
Sleeping
Sleeping
Removed GPU for building models and transforms
Browse files
app.py
CHANGED
@@ -116,7 +116,6 @@ def get_device():
|
|
116 |
return torch.device('cpu')
|
117 |
|
118 |
# Get counting model.
|
119 |
-
@spaces.GPU
|
120 |
def build_model_and_transforms(args):
|
121 |
normalize = T.Compose(
|
122 |
[T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]
|
@@ -137,7 +136,8 @@ def build_model_and_transforms(args):
|
|
137 |
else:
|
138 |
raise ValueError("Key {} can used by args only".format(k))
|
139 |
|
140 |
-
|
|
|
141 |
# fix the seed for reproducibility
|
142 |
seed = 42
|
143 |
torch.manual_seed(seed)
|
@@ -156,7 +156,6 @@ def build_model_and_transforms(args):
|
|
156 |
model.load_state_dict(checkpoint, strict=False)
|
157 |
|
158 |
model.eval()
|
159 |
-
model.cpu()
|
160 |
|
161 |
return model, data_transform
|
162 |
|
@@ -164,7 +163,7 @@ def build_model_and_transforms(args):
|
|
164 |
parser = argparse.ArgumentParser("Counting Application", parents=[get_args_parser()])
|
165 |
args = parser.parse_args()
|
166 |
|
167 |
-
|
168 |
model, transform = build_model_and_transforms(args)
|
169 |
|
170 |
examples = [
|
@@ -406,7 +405,7 @@ As shown earlier, there are 3 ways to specify the object to count: (1) with text
|
|
406 |
|
407 |
with gr.Blocks(title="CountGD: Multi-Modal Open-World Counting", theme="soft", head="""<meta name="viewport" content="width=device-width, initial-scale=1, user-scalable=1">""") as demo:
|
408 |
state = gr.State(value=[AppSteps.JUST_TEXT])
|
409 |
-
device = gr.State(
|
410 |
with gr.Tab("Tutorial"):
|
411 |
with gr.Row():
|
412 |
with gr.Column():
|
|
|
116 |
return torch.device('cpu')
|
117 |
|
118 |
# Get counting model.
|
|
|
119 |
def build_model_and_transforms(args):
|
120 |
normalize = T.Compose(
|
121 |
[T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]
|
|
|
136 |
else:
|
137 |
raise ValueError("Key {} can used by args only".format(k))
|
138 |
|
139 |
+
# Start with model on CPU.
|
140 |
+
args.device = "cpu"
|
141 |
# fix the seed for reproducibility
|
142 |
seed = 42
|
143 |
torch.manual_seed(seed)
|
|
|
156 |
model.load_state_dict(checkpoint, strict=False)
|
157 |
|
158 |
model.eval()
|
|
|
159 |
|
160 |
return model, data_transform
|
161 |
|
|
|
163 |
parser = argparse.ArgumentParser("Counting Application", parents=[get_args_parser()])
|
164 |
args = parser.parse_args()
|
165 |
|
166 |
+
device = get_device()
|
167 |
model, transform = build_model_and_transforms(args)
|
168 |
|
169 |
examples = [
|
|
|
405 |
|
406 |
with gr.Blocks(title="CountGD: Multi-Modal Open-World Counting", theme="soft", head="""<meta name="viewport" content="width=device-width, initial-scale=1, user-scalable=1">""") as demo:
|
407 |
state = gr.State(value=[AppSteps.JUST_TEXT])
|
408 |
+
device = gr.State(device)
|
409 |
with gr.Tab("Tutorial"):
|
410 |
with gr.Row():
|
411 |
with gr.Column():
|