from typing import Union, List

import torch


def is_differentiable(
        loss: torch.Tensor, model: Union[torch.nn.Module, List[torch.nn.Module]], print_instead: bool = False
) -> None:
    """
    Overview:
        Check that back-propagating ``loss`` produces a gradient for every parameter of the \
        given module(s). The check runs in three steps: verify that no parameter holds a \
        gradient yet, call ``loss.backward()``, then verify that every parameter's ``grad`` \
        is a ``torch.Tensor``.
    Arguments:
        - loss (:obj:`torch.Tensor`): Loss tensor on which ``backward()`` is called.
        - model (:obj:`Union[torch.nn.Module, List[torch.nn.Module]]`): A single module or a \
            list of modules whose parameters are inspected.
        - print_instead (:obj:`bool`): If ``True``, print the name and ``grad`` of every \
            parameter that did not receive a tensor gradient instead of asserting. \
            Defaults to ``False``.
    Raises:
        - AssertionError: If ``loss`` is not a tensor, a list element is not a module, a \
            parameter already has a gradient before the backward pass, or (when \
            ``print_instead`` is ``False``) a parameter has no tensor gradient afterwards. \
            The parameter name is used as the assertion message for the gradient checks.
        - TypeError: If ``model`` is neither a list nor a ``torch.nn.Module``.
    """
    assert isinstance(loss, torch.Tensor)

    # Normalise the input so both phases below share a single code path.
    if isinstance(model, list):
        modules = model
    elif isinstance(model, torch.nn.Module):
        modules = [model]
    else:
        raise TypeError('model must be list or nn.Module')

    # Phase 1: before the backward pass no parameter may already hold a gradient,
    # otherwise the post-check could not tell old gradients from new ones.
    for module in modules:
        assert isinstance(module, torch.nn.Module)
        for name, param in module.named_parameters():
            assert param.grad is None, name

    loss.backward()

    # Phase 2: every parameter must now carry a tensor gradient.
    for module in modules:
        for name, param in module.named_parameters():
            has_grad = isinstance(param.grad, torch.Tensor)
            if not print_instead:
                assert has_grad, name
            elif not has_grad:
                print(name, "grad is:", param.grad)