File size: 456 Bytes
ae81e0f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
import numpy as np
def count_parameters(model, requires_grad: bool = True) -> int:
    """
    Return the total number of scalar elements across a model's parameters.

    Args:
        model: Object whose ``parameters()`` method yields parameter objects
            exposing ``size()`` and, when ``requires_grad`` is True, a
            ``requires_grad`` attribute (presumably a ``torch.nn.Module``;
            not enforced here).
        requires_grad: If True (default), count only parameters whose
            ``requires_grad`` is truthy, i.e. the trainable ones. If False,
            count every parameter.

    Returns:
        The element count as a plain Python ``int``; 0 when no parameter
        matches.
    """
    params = model.parameters()
    if requires_grad:
        params = (p for p in params if p.requires_grad)
    # ``parameters()`` may return a one-shot generator, so iterate exactly
    # once and let any real error propagate instead of retrying over an
    # exhausted iterator. ``int()`` converts numpy's scalar to a Python int
    # and also maps ``np.prod(()) == 1.0`` (zero-dim parameter) to ``1``,
    # keeping the total integral. ``sum`` of an empty iterable is ``0``.
    return sum(int(np.prod(p.size())) for p in params)
|