|
from typing import * |
|
import torch |
|
import torch.nn as nn |
|
from .. import models |
|
|
|
|
|
class Pipeline:
    """
    A base class for pipelines.

    A pipeline holds a named collection of ``nn.Module`` components in
    ``self.models`` and provides helpers to load them from a pretrained
    location and to move them between devices.
    """

    def __init__(
        self,
        models: Optional[dict[str, nn.Module]] = None,
    ):
        """
        Args:
            models: Mapping from component name to module. Every module is
                switched to eval mode. When ``None``, ``self.models`` is an
                empty dict.
        """
        if models is None:
            # Always define ``self.models`` so ``to``/``cpu``/``cuda``/``device``
            # do not fail with AttributeError on an empty pipeline.
            self.models = {}
            return

        self.models = models
        for model in self.models.values():
            model.eval()

    @staticmethod
    def from_pretrained(path: str) -> "Pipeline":
        """
        Load a pretrained pipeline.

        If ``{path}/pipeline.json`` exists locally it is read directly;
        otherwise ``path`` is treated as a Hugging Face Hub repo id and
        ``pipeline.json`` is downloaded with ``hf_hub_download``. The file's
        ``args['models']`` maps component names to sub-paths, and each
        component is loaded with ``models.from_pretrained(f"{path}/{sub_path}")``.

        Args:
            path: Local directory or Hub repo id.

        Returns:
            A ``Pipeline`` holding the loaded components, with the raw
            ``args`` dict from ``pipeline.json`` stored as
            ``_pretrained_args``.
        """
        import os
        import json

        config_file = f"{path}/pipeline.json"
        if not os.path.exists(config_file):
            from huggingface_hub import hf_hub_download
            config_file = hf_hub_download(path, "pipeline.json")

        # JSON is UTF-8 by spec; don't depend on the platform locale encoding.
        with open(config_file, 'r', encoding="utf-8") as f:
            args = json.load(f)['args']

        # ``models`` here is the package module imported at file top; the
        # ``__init__`` parameter of the same name is not in scope here.
        _models = {
            k: models.from_pretrained(f"{path}/{v}")
            for k, v in args['models'].items()
        }

        # This is a staticmethod (no ``cls``), so the base class is
        # instantiated explicitly.
        new_pipeline = Pipeline(_models)
        new_pipeline._pretrained_args = args
        return new_pipeline

    @property
    def device(self) -> torch.device:
        """
        Device of the pipeline's components.

        Returns the ``device`` attribute of the first component that has one;
        otherwise the device of the first parameter of the first component
        that owns any parameters.

        Raises:
            RuntimeError: if no component exposes a device.
        """
        for model in self.models.values():
            if hasattr(model, 'device'):
                return model.device
        for model in self.models.values():
            if hasattr(model, 'parameters'):
                # Parameter-free modules (e.g. nn.Identity) would make a bare
                # ``next(...)`` leak StopIteration; skip them instead.
                first_param = next(iter(model.parameters()), None)
                if first_param is not None:
                    return first_param.device
        raise RuntimeError("No device found.")

    def to(self, device: torch.device) -> None:
        """Move every component to ``device`` (via each module's ``.to``)."""
        for model in self.models.values():
            model.to(device)

    def cuda(self) -> None:
        """Move every component to the default CUDA device."""
        self.to(torch.device("cuda"))

    def cpu(self) -> None:
        """Move every component to the CPU."""
        self.to(torch.device("cpu"))
|
|