import numpy as np | |
def count_parameters(model, requires_grad: bool = True) -> int:
    """
    Return the total number of scalar elements across a model's parameters.

    Args:
        model: object exposing ``parameters()``, which yields parameter
            objects with a ``size()`` method and a ``requires_grad``
            attribute (presumably a ``torch.nn.Module``).
        requires_grad: if True (default), count only parameters whose
            ``requires_grad`` flag is set, i.e. the trainable ones. If False,
            count every parameter.

    Returns:
        The element count as a plain Python ``int``; 0 when no parameter
        matches.

    Raises:
        Any exception raised while iterating ``model.parameters()`` or calling
        ``size()`` propagates unchanged.
    """
    params = model.parameters()
    if requires_grad:
        params = (p for p in params if p.requires_grad)
    # ``params`` is a one-shot iterator, so it is consumed exactly once here.
    # Converting each product to a Python int avoids two pitfalls:
    # ``np.prod(())`` (a 0-dim parameter) is the float 1.0, and a
    # platform-default int32 accumulator could overflow on large models.
    return sum(int(np.prod(p.size(), dtype=np.int64)) for p in params)