import math

import numpy as np


def count_parameters(model, requires_grad: bool = True) -> int:
    """Return the total number of parameter elements in ``model``.

    Args:
        model: Object whose ``parameters()`` yields items that have a
            ``size()`` method (returning a sequence of ints) and a
            ``requires_grad`` attribute, e.g. a ``torch.nn.Module``.
        requires_grad: If True (default), count only parameters whose
            ``requires_grad`` is truthy, i.e. the trainable ones. If False,
            count every parameter.

    Returns:
        The summed element count as a plain Python ``int``. Returns 0 when
        no parameters match.
    """
    params = model.parameters()
    if requires_grad:
        params = (p for p in params if p.requires_grad)
    # math.prod returns an exact int and gives 1 for a 0-d (scalar) parameter.
    # np.prod(()) returns the float 1.0, which used to turn the total into a float.
    # sum() over an empty iterable is 0, so the empty model needs no
    # special-casing.
    return sum(math.prod(p.size()) for p in params)