File size: 1,119 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import copy
import torch
from easydict import EasyDict
from ding.utils import import_module, MODEL_REGISTRY


def create_model(cfg: EasyDict) -> torch.nn.Module:
    """
    Overview:
        Create a neural network model according to the given EasyDict-type ``cfg``.
    Arguments:
        - cfg: (:obj:`EasyDict`): User's model config. The optional key ``import_names`` lists the modules \
            that are imported (via ``import_module``) before the model is built, and the required key \
            ``type`` is the name under which the model is looked up in ``MODEL_REGISTRY``. All remaining \
            keys are passed to the model constructor as keyword arguments.
    Returns:
        - (:obj:`torch.nn.Module`): The created neural network model.
    Raises:
        - KeyError: If ``cfg`` does not contain the key ``type``.
    Examples:
        >>> cfg = EasyDict({
        ...     'import_names': ['ding.model.template.q_learning'],
        ...     'type': 'dqn',
        ...     'obs_shape': 4,
        ...     'action_shape': 2,
        ... })
        >>> model = create_model(cfg)

    .. tip::
        This method will not modify the ``cfg`` , it will deepcopy the ``cfg`` and then modify it.
    """
    # Work on a private copy so that popping keys below never mutates the caller's config.
    cfg = copy.deepcopy(cfg)
    # ``import_names`` (and ``type`` below) must be popped rather than just read: everything left
    # in ``cfg`` is forwarded as ``**cfg`` to the model constructor, which does not expect them.
    import_module(cfg.pop('import_names', []))
    if 'type' not in cfg:
        # Same exception type as the bare ``cfg.pop('type')`` would raise, but with context.
        raise KeyError(
            "Model config must contain the key 'type' to select the model to build; "
            "got keys: {}".format(list(cfg.keys()))
        )
    model_type = cfg.pop('type')
    return MODEL_REGISTRY.build(model_type, **cfg)