liuyizhang
commited on
Commit
•
9988241
1
Parent(s):
bf71fd5
update app.py
Browse files
app.py
CHANGED
@@ -347,7 +347,7 @@ config = dict(
|
|
347 |
)
|
348 |
config = Config(config)
|
349 |
|
350 |
-
class
|
351 |
def __init__(self,config):
|
352 |
self.config = config
|
353 |
self.device = torch.device(device)
|
@@ -358,7 +358,7 @@ class Predictor(RamPredictor, device='cpu'):
|
|
358 |
if self.config.load_from is not None:
|
359 |
self.model.load_state_dict(torch.load(self.config.load_from, map_location=self.device))
|
360 |
self.model.train()
|
361 |
-
ram_model =
|
362 |
|
363 |
# visualization
|
364 |
def draw_selected_mask(mask, draw):
|
|
|
347 |
)
|
348 |
config = Config(config)
|
349 |
|
350 |
+
class Ram_Predictor(RamPredictor, device='cpu'):
|
351 |
def __init__(self,config):
|
352 |
self.config = config
|
353 |
self.device = torch.device(device)
|
|
|
358 |
if self.config.load_from is not None:
|
359 |
self.model.load_state_dict(torch.load(self.config.load_from, map_location=self.device))
|
360 |
self.model.train()
|
361 |
+
ram_model = Ram_Predictor(config, device)
|
362 |
|
363 |
# visualization
|
364 |
def draw_selected_mask(mask, draw):
|