# Source: HF-SillyTavern-Extras/talkinghead/tha3/compute/cached_computation_func.py
# Uploaded by TomatoCocotree (commit 6a62ffb); original upload size: 329 bytes.
from typing import Callable, Dict, List
from torch import Tensor
from torch.nn import Module
# Type of a callable that produces ONE tensor.
#
# It is called with three positional arguments, in this order:
#   1. Dict[str, Module]        -- torch modules keyed by string name
#                                  (presumably the named sub-networks the
#                                  computation uses; confirm with callers)
#   2. List[Tensor]             -- a list of input tensors
#   3. Dict[str, List[Tensor]]  -- string-keyed lists of tensors
#                                  (NOTE(review): given the module name this
#                                  looks like a cache of earlier results;
#                                  confirm against the code that builds it)
# and returns a single Tensor.
TensorCachedComputationFunc = Callable[
    [
        Dict[str, Module],
        List[Tensor],
        Dict[str, List[Tensor]],
    ],
    Tensor,
]
# Type of a callable that produces a LIST of tensors.
#
# The positional arguments are the same three as for the single-tensor
# variant, in the same order:
#   1. Dict[str, Module]        -- torch modules keyed by string name
#                                  (presumably named sub-networks; confirm)
#   2. List[Tensor]             -- a list of input tensors
#   3. Dict[str, List[Tensor]]  -- string-keyed lists of tensors
#                                  (NOTE(review): likely a cache of earlier
#                                  results; confirm against its producer)
# The only difference is the return type: List[Tensor] instead of a single
# Tensor.
TensorListCachedComputationFunc = Callable[
    [
        Dict[str, Module],
        List[Tensor],
        Dict[str, List[Tensor]],
    ],
    List[Tensor],
]