"""Load a saved depthwise-convolution module and a hidden-state input,
run the module on the input, and print the result.

Run as a script: ``python <this file>``.
"""
import torch


def main(
    conv_path="depthwise_conv.pt",
    hidden_path="hidden_states.pt",
    map_location=None,
):
    """Apply the module pickled at *conv_path* to the object at *hidden_path*.

    Args:
        conv_path: File written by ``torch.save`` holding a whole callable
            module (not just a ``state_dict``). The default filename suggests
            a depthwise convolution; confirm against whoever produces it.
        hidden_path: File written by ``torch.save`` holding the input passed
            to the module.
        map_location: Forwarded to ``torch.load``. For example, ``"cpu"``
            loads GPU-saved files on a CPU-only host. ``None`` keeps the
            saved devices, which matches the original script.

    Returns:
        Whatever the loaded module returns for the loaded input. The value
        is also printed.
    """
    # A whole nn.Module can only be restored through full unpickling. Since
    # PyTorch 2.6, torch.load defaults to weights_only=True, which rejects
    # this file, so opt out explicitly here.
    # SECURITY: weights_only=False executes arbitrary pickle code; only load
    # files from a trusted source.
    conv = torch.load(conv_path, map_location=map_location, weights_only=False)
    hidden = torch.load(hidden_path, map_location=map_location)

    # Inference only: nothing here uses gradients, so skip building the
    # autograd graph.
    with torch.no_grad():
        output = conv(hidden)

    print(output)
    return output


if __name__ == "__main__":
    main()