|
|
|
|
|
|
|
|
|
|
|
import uuid |
|
from typing import Dict, Optional |
|
|
|
from torch import Tensor |
|
|
|
|
|
class FairseqIncrementalState:
    """Mixin that namespaces an object's entries in a shared incremental-state dict.

    On construction every instance receives a random UUID string. All keys the
    instance reads from or writes to the shared ``incremental_state`` dict are
    prefixed with that id (``"<id>.<key>"``), so several objects can keep
    their state in the same dict without their keys colliding.
    """

    def __init__(self, *args, **kwargs):
        # Let the rest of the MRO consume the constructor arguments first,
        # then give this instance its own namespace id.
        super().__init__(*args, **kwargs)
        self.init_incremental_state()

    def init_incremental_state(self):
        """(Re)assign this instance a fresh random identifier string."""
        self._incremental_state_id = str(uuid.uuid4())

    def _get_full_incremental_state_key(self, key: str) -> str:
        """Return ``key`` namespaced by this instance's id, as ``"<id>.<key>"``."""
        return "{}.{}".format(self._incremental_state_id, key)

    def get_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
    ) -> Optional[Dict[str, Optional[Tensor]]]:
        """Helper for getting incremental state for an nn.Module.

        Args:
            incremental_state: shared state dict, or ``None``.
            key: un-namespaced key for this instance's entry.

        Returns:
            The stored entry, or ``None`` when ``incremental_state`` is
            ``None`` or holds nothing under this instance's namespaced key.
        """
        namespaced = self._get_full_incremental_state_key(key)
        if incremental_state is not None and namespaced in incremental_state:
            return incremental_state[namespaced]
        return None

    def set_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
        value: Dict[str, Optional[Tensor]],
    ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
        """Helper for setting incremental state for an nn.Module.

        Stores ``value`` in ``incremental_state`` (mutated in place) under this
        instance's namespaced ``key``. When ``incremental_state`` is ``None``
        nothing is stored.

        Returns:
            The same ``incremental_state`` object that was passed in.
        """
        if incremental_state is not None:
            incremental_state[self._get_full_incremental_state_key(key)] = value
        return incremental_state
|
|
|
|
|
def with_incremental_state(cls):
    """Class decorator that makes ``cls`` inherit from FairseqIncrementalState.

    FairseqIncrementalState is placed first in ``cls.__bases__``. Any existing
    occurrence of it is dropped so it is not listed twice. The remaining bases
    keep their original order. ``cls`` is modified in place and returned.

    NOTE(review): CPython rejects ``__bases__`` reassignment for some base
    layouts (e.g. a class whose only base is ``object`` raises TypeError);
    presumably this is only applied to classes with a compatible base such as
    nn.Module subclasses -- confirm with callers.
    """
    other_bases = [base for base in cls.__bases__ if base != FairseqIncrementalState]
    cls.__bases__ = (FairseqIncrementalState, *other_bases)
    return cls
|
|