|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Accelerate utilities: Utilities related to accelerate |
|
""" |
|
|
|
import functools

from packaging import version

from .import_utils import is_accelerate_available
|
|
|
|
|
if is_accelerate_available(): |
|
import accelerate |
|
|
|
|
|
def apply_forward_hook(method):
    """
    Decorator that applies a registered CpuOffload hook to an arbitrary function rather than `forward`. This is useful
    for cases where a PyTorch module provides functions other than `forward` that should trigger a move to the
    appropriate acceleration device. This is the case for `encode` and `decode` in [`AutoencoderKL`].

    This decorator looks inside the internal `_hf_hook` property to find a registered offload hook.

    If `accelerate` is not installed, or the installed version is older than 0.17.0, `method` is returned unchanged.

    :param method: The method to decorate. This method should be a method of a PyTorch module.
    :return: A wrapper that first calls `self._hf_hook.pre_forward(self)` (when such a hook is present) and then
        delegates to `method`. The wrapper keeps `method`'s name, docstring and other metadata via
        `functools.wraps`. When the hook mechanism is unavailable, `method` itself is returned.
    """
    if not is_accelerate_available():
        return method

    # Compare on the base version so pre-/dev-releases such as "0.17.0.dev0" are treated as 0.17.0.
    accelerate_version = version.parse(accelerate.__version__).base_version
    if version.parse(accelerate_version) < version.parse("0.17.0"):
        return method

    # Preserve `method`'s __name__/__doc__/__qualname__ so decorated methods stay introspectable.
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        # Only fire the hook when one has been attached to the module and it exposes `pre_forward`.
        if hasattr(self, "_hf_hook") and hasattr(self._hf_hook, "pre_forward"):
            self._hf_hook.pre_forward(self)

        return method(self, *args, **kwargs)

    return wrapper
|
|