|
from abc import ABC, abstractmethod |
|
from typing import Dict, List |
|
|
|
from torch import Tensor |
|
from torch.nn import Module |
|
|
|
from tha3.compute.cached_computation_func import TensorCachedComputationFunc, TensorListCachedComputationFunc |
|
|
|
|
|
class CachedComputationProtocol(ABC):
    """Lazily compute named lists of tensors and memoize them in a caller-owned dict.

    Subclasses implement :meth:`compute_output`. Callers go through
    :meth:`get_output`, which consults ``outputs`` first and only falls back
    to ``compute_output`` when ``key`` is missing. The ``outputs`` dict is
    mutated in place, so reusing one dict across calls shares the cache.
    """

    def get_output(self,
                   key: str,
                   modules: Dict[str, Module],
                   batch: List[Tensor],
                   outputs: Dict[str, List[Tensor]]) -> List[Tensor]:
        """Return ``outputs[key]``, computing and storing it first if absent.

        Args:
            key: Name of the requested output.
            modules: Forwarded unchanged to :meth:`compute_output`.
            batch: Forwarded unchanged to :meth:`compute_output`.
            outputs: Cache of already-computed outputs; updated in place on a miss.

        Returns:
            The list stored under ``key`` in ``outputs``.
        """
        if key not in outputs:
            # Cache miss: compute once and store, so later lookups hit the dict.
            outputs[key] = self.compute_output(key, modules, batch, outputs)
        return outputs[key]

    @abstractmethod
    def compute_output(self,
                       key: str,
                       modules: Dict[str, Module],
                       batch: List[Tensor],
                       outputs: Dict[str, List[Tensor]]) -> List[Tensor]:
        """Compute the output named ``key`` (subclass hook; called on cache misses).

        ``outputs`` is the same dict :meth:`get_output` uses as its cache, so an
        implementation can presumably obtain dependencies via ``get_output`` —
        NOTE(review): confirm against concrete subclasses.
        """

    def get_output_tensor_func(self, key: str, index: int) -> TensorCachedComputationFunc:
        """Build a callable that yields a single tensor of the output named ``key``.

        The returned callable takes ``(modules, batch, outputs)``, resolves
        ``key`` through :meth:`get_output` (sharing the same cache), and
        returns the element at position ``index`` of the resulting list.
        """
        def select_tensor(modules: Dict[str, Module],
                          batch: List[Tensor],
                          outputs: Dict[str, List[Tensor]]) -> Tensor:
            tensors = self.get_output(key, modules, batch, outputs)
            return tensors[index]

        return select_tensor

    def get_output_tensor_list_func(self, key: str) -> TensorListCachedComputationFunc:
        """Build a callable that yields the full list of tensors named ``key``.

        The returned callable takes ``(modules, batch, outputs)`` and returns
        whatever :meth:`get_output` produces for ``key``.
        """
        def select_tensor_list(modules: Dict[str, Module],
                               batch: List[Tensor],
                               outputs: Dict[str, List[Tensor]]) -> List[Tensor]:
            tensors = self.get_output(key, modules, batch, outputs)
            return tensors

        return select_tensor_list