File size: 521 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch


def get_num_params(model: torch.nn.Module) -> int:
    """
    Overview:
        Return the total number of scalar elements across all parameters of the model. \
        Parameters of submodules are included, because ``Module.parameters()`` recurses by default. \
        A parameter shared by several submodules is yielded only once, so it is counted once.
    Arguments:
        - model (:obj:`torch.nn.Module`): The model object to calculate the parameter number.
    Returns:
        - n_params (:obj:`int`): The calculated number of parameters (``0`` for a module without parameters).
    Examples:
        >>> model = torch.nn.Linear(3, 5)  # weight 5x3 (15) + bias 5
        >>> num = get_num_params(model)
        >>> assert num == 20
        >>> assert get_num_params(torch.nn.Linear(3, 5, bias=False)) == 15
    """
    # numel() gives the element count of each parameter tensor; summing an
    # empty generator yields the int 0, so parameter-free modules return 0.
    return sum(p.numel() for p in model.parameters())