# Spaces: Paused  (appears to be page-header residue from where this file was copied; not part of the module)
import torch
import torch.nn as nn

from mono.utils.comm import get_func
class BaseDepthModel(nn.Module):
    """Thin wrapper around a depth pipeline that is selected by configuration.

    The wrapped pipeline is looked up by name under
    ``mono.model.model_pipelines`` and stored as ``self.depth_model``.
    """

    def __init__(self, cfg, **kwargs) -> None:
        """Build the pipeline named by ``cfg.model.type``.

        Args:
            cfg: Configuration object. ``cfg.model.type`` is appended to
                ``'mono.model.model_pipelines.'`` to form the name handed to
                ``get_func``; the callable that comes back is invoked with
                ``cfg`` itself.
            **kwargs: Accepted for call-site compatibility; not used here.
        """
        super().__init__()
        model_type = cfg.model.type
        pipeline_factory = get_func('mono.model.model_pipelines.' + model_type)
        self.depth_model = pipeline_factory(cfg)

    def forward(self, data):
        """Run the wrapped pipeline on ``data``.

        Args:
            data: Mapping that is unpacked as keyword arguments into
                ``self.depth_model``.

        Returns:
            A 3-tuple ``(prediction, confidence, output)`` where ``output`` is
            whatever the pipeline returned and the first two items are its
            ``'prediction'`` and ``'confidence'`` entries.

        Raises:
            KeyError: If the pipeline output has no ``'prediction'`` or
                ``'confidence'`` entry (assuming a dict-like output).
        """
        output = self.depth_model(**data)
        return output['prediction'], output['confidence'], output

    def inference(self, data):
        """Run ``forward`` with gradient tracking disabled.

        Args:
            data: Same mapping that ``forward`` accepts.

        Returns:
            A 2-tuple ``(pred_depth, confidence)``; the full pipeline output
            returned by ``forward`` is discarded.
        """
        # NOTE: forward() is called directly rather than through self(...),
        # so hooks registered on this wrapper module are not triggered here.
        with torch.no_grad():
            pred_depth, confidence, _ = self.forward(data)
        return pred_depth, confidence