# Listing metadata: file size 1,568 bytes, revision 6a62ffb.
from abc import ABC, abstractmethod
from typing import Dict, List
from torch import Tensor
from torch.nn import Module
from tha3.compute.cached_computation_func import TensorCachedComputationFunc, TensorListCachedComputationFunc
class CachedComputationProtocol(ABC):
    """Abstract base for keyed computations memoized in a caller-owned dict.

    Subclasses supply :meth:`compute_output`. All lookups go through
    :meth:`get_output`, which stores each freshly computed result in the
    ``outputs`` mapping under its key and reuses that entry on later calls.
    """

    def get_output(self,
                   key: str,
                   modules: Dict[str, Module],
                   batch: List[Tensor],
                   outputs: Dict[str, List[Tensor]]):
        """Return the entry for ``key`` from ``outputs``, computing it if missing.

        Args:
            key: Name of the requested output.
            modules: Forwarded unchanged to :meth:`compute_output`.
            batch: Forwarded unchanged to :meth:`compute_output`.
            outputs: Cache mapping; mutated in place when ``key`` is absent.

        Returns:
            The value stored in ``outputs[key]`` after the call.
        """
        # Only a cache miss triggers computation; a hit is returned untouched.
        if key not in outputs:
            outputs[key] = self.compute_output(key, modules, batch, outputs)
        return outputs[key]

    @abstractmethod
    def compute_output(self,
                       key: str,
                       modules: Dict[str, Module],
                       batch: List[Tensor],
                       outputs: Dict[str, List[Tensor]]) -> List[Tensor]:
        """Produce the value for ``key``; must be implemented by subclasses.

        Receives the same ``outputs`` cache that :meth:`get_output` was given.
        """

    def get_output_tensor_func(self, key: str, index: int) -> TensorCachedComputationFunc:
        """Build a callable returning element ``index`` of the output for ``key``."""

        def select_tensor(modules: Dict[str, Module],
                          batch: List[Tensor],
                          outputs: Dict[str, List[Tensor]]):
            tensors = self.get_output(key, modules, batch, outputs)
            return tensors[index]

        return select_tensor

    def get_output_tensor_list_func(self, key: str) -> TensorListCachedComputationFunc:
        """Build a callable returning the whole output stored for ``key``."""

        def select_tensor_list(modules: Dict[str, Module],
                               batch: List[Tensor],
                               outputs: Dict[str, List[Tensor]]):
            return self.get_output(key, modules, batch, outputs)

        return select_tensor_list