Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Sequence, Union | |
import torch | |
from mmengine.registry import HOOKS | |
from .hook import Hook | |
# Type alias for the ``data_batch`` argument of iteration callbacks: a dict,
# tuple or list, or ``None`` when no batch is supplied.
DATA_BATCH = Union[dict, tuple, list, None]
class EmptyCacheHook(Hook):
    """Hook that calls ``torch.cuda.empty_cache()`` at selected points.

    Each of the three trigger points (before an epoch, after an epoch, after
    an iteration) can be switched on or off independently through the
    constructor flags.

    Args:
        before_epoch (bool): Release the cache before each epoch.
            Defaults to False.
        after_epoch (bool): Release the cache after each epoch.
            Defaults to True.
        after_iter (bool): Release the cache after each iteration.
            Defaults to False.
    """

    priority = 'NORMAL'

    def __init__(self,
                 before_epoch: bool = False,
                 after_epoch: bool = True,
                 after_iter: bool = False) -> None:
        # One independent on/off switch per trigger point.
        self._do_after_iter = after_iter
        self._do_after_epoch = after_epoch
        self._do_before_epoch = before_epoch

    def _before_epoch(self, runner, mode: str = 'train') -> None:
        """Release the cache before an epoch, if enabled.

        Args:
            runner (Runner): The runner of the training process. Not read by
                this hook.
            mode (str): Current mode of runner. Defaults to 'train'. Not read
                by this hook.
        """
        if not self._do_before_epoch:
            return
        torch.cuda.empty_cache()

    def _after_epoch(self, runner, mode: str = 'train') -> None:
        """Release the cache after an epoch, if enabled.

        Args:
            runner (Runner): The runner of the training process. Not read by
                this hook.
            mode (str): Current mode of runner. Defaults to 'train'. Not read
                by this hook.
        """
        if not self._do_after_epoch:
            return
        torch.cuda.empty_cache()

    def _after_iter(self,
                    runner,
                    batch_idx: int,
                    data_batch: DATA_BATCH = None,
                    outputs: Optional[Union[dict, Sequence]] = None,
                    mode: str = 'train') -> None:
        """Release the cache after an iteration, if enabled.

        None of the arguments are read by this hook; only the
        ``after_iter`` constructor flag decides whether anything happens.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the loop.
            data_batch (dict or tuple or list, optional): Data from
                dataloader.
            outputs (dict or sequence, optional): Outputs from model.
            mode (str): Current mode of runner. Defaults to 'train'.
        """
        if not self._do_after_iter:
            return
        torch.cuda.empty_cache()