import pytorch_lightning as pl

from src.ss.det_models.model import POIDetection
from src.ss.datasets_signboard_detection.datamodule import POIDataModule
from src.ss.det_models.inference_signboard_detection import POIDetectionTask


def load_model(checkpoint_path):
    """Load a ``POIDetection`` model from a Lightning checkpoint.

    Args:
        checkpoint_path: Checkpoint path forwarded to
            ``POIDetection.load_from_checkpoint``.

    Returns:
        The restored ``POIDetection`` instance.
    """
    return POIDetection.load_from_checkpoint(checkpoint_path=checkpoint_path)


def _predict_datamodule(data_path):
    """Build a ``POIDataModule`` for ``data_path`` and set it up for prediction."""
    dm = POIDataModule(data_path=data_path, seed=42)
    dm.setup("predict")
    return dm


def _build_trainer():
    """Create the single-GPU Lightning trainer used for prediction.

    ``accelerator="gpu", devices=1`` replaces ``gpus=1``, which Lightning
    deprecated in 1.7 and removed in 2.0; on 1.x both request one GPU.
    """
    # max_epochs has no effect on trainer.predict; kept to mirror the original setup.
    return pl.Trainer(accelerator="gpu", devices=1, max_epochs=-1)


def inference_signboard(image_path, checkpoint, score):
    """Run signboard detection on ``image_path`` with a model loaded from ``checkpoint``.

    The model is reloaded on every call; use ``SignBoardDetector`` to load it once
    and reuse it.

    Args:
        image_path: Data path given to both ``POIDataModule`` and
            ``POIDetectionTask`` (file vs. directory not determinable here).
        checkpoint: Checkpoint path passed to ``load_model``.
        score: Forwarded to ``POIDetectionTask``; presumably a confidence
            threshold — confirm against ``POIDetectionTask``.

    Returns:
        ``task.output`` after ``trainer.predict`` has run (presumably the
        detections collected by ``POIDetectionTask`` — confirm).
    """
    dm = _predict_datamodule(image_path)
    model = load_model(checkpoint)
    task = POIDetectionTask(model, data_path=image_path, score=score)
    _build_trainer().predict(task, datamodule=dm)
    return task.output


class SignBoardDetector:
    """Signboard detector that loads its ``POIDetection`` model once and reuses it."""

    def __init__(self, checkpoint) -> None:
        """Load the model from ``checkpoint`` via ``load_model``."""
        self.model = load_model(checkpoint)

    def inference_signboard(self, image, score):
        """Run detection on ``image`` with the preloaded model.

        Args:
            image: Data path given to ``POIDataModule``.
            score: Forwarded to ``POIDetectionTask``; presumably a confidence
                threshold — confirm.

        Returns:
            ``task.output`` after ``trainer.predict`` has run.
        """
        dm = _predict_datamodule(image)
        # Unlike the module-level function, the task is not given data_path here;
        # this mirrors the original call exactly.
        task = POIDetectionTask(self.model, score=score)
        _build_trainer().predict(task, datamodule=dm)
        return task.output