Spaces:
Runtime error
Runtime error
Update predict.py
Browse files- predict.py +2 -2
predict.py
CHANGED
@@ -10,13 +10,13 @@ os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
|
|
10 |
from detectron2.data import MetadataCatalog
|
11 |
from detectron2.utils.visualizer import ColorMode, Visualizer
|
12 |
from color_palette import ade_palette
|
13 |
-
from transformers import
|
14 |
|
15 |
def load_model_and_processor(model_ckpt: str):
|
16 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
17 |
model = Mask2FormerForUniversalSegmentation.from_pretrained(model_ckpt).to(torch.device(device))
|
18 |
model.eval()
|
19 |
-
image_preprocessor =
|
20 |
return model, image_preprocessor
|
21 |
|
22 |
def load_default_ckpt(segmentation_task: str):
|
|
|
10 |
from detectron2.data import MetadataCatalog
|
11 |
from detectron2.utils.visualizer import ColorMode, Visualizer
|
12 |
from color_palette import ade_palette
|
13 |
+
from transformers import Mask2FormerImageProcessor, Mask2FormerForUniversalSegmentation
|
14 |
|
15 |
def load_model_and_processor(model_ckpt: str):
|
16 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
17 |
model = Mask2FormerForUniversalSegmentation.from_pretrained(model_ckpt).to(torch.device(device))
|
18 |
model.eval()
|
19 |
+
image_preprocessor = Mask2FormerImageProcessor.from_pretrained(model_ckpt)
|
20 |
return model, image_preprocessor
|
21 |
|
22 |
def load_default_ckpt(segmentation_task: str):
|