Spaces:
Build error
Build error
File size: 514 Bytes
3304f7d ddc8a59 3304f7d ddc8a59 3304f7d ddc8a59 3304f7d ddc8a59 3304f7d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
from typing import Dict
import numpy as np
import torch
def run_assertion(
    orig_pt_state_dict: Dict[str, torch.Tensor],
    pt_state_dict_from_tf: Dict[str, torch.Tensor],
    *,
    rtol: float = 1e-7,
    atol: float = 0.0,
) -> None:
    """Check that every parameter of a reference state dict was populated correctly.

    Each tensor in ``orig_pt_state_dict`` is compared element-wise against the
    tensor stored under the same key in ``pt_state_dict_from_tf`` using
    ``numpy.testing.assert_allclose``. Keys that appear only in
    ``pt_state_dict_from_tf`` are not checked.

    Args:
        orig_pt_state_dict: Reference tensors, keyed by parameter name.
        pt_state_dict_from_tf: Tensors to validate, keyed by parameter name.
        rtol: Relative tolerance forwarded to ``assert_allclose``. The default
            equals numpy's own default, so existing callers see no change.
        atol: Absolute tolerance forwarded to ``assert_allclose`` (numpy's
            default is 0).

    Raises:
        ValueError: If a key is missing from ``pt_state_dict_from_tf``, the
            tensors differ in shape or in value beyond tolerance, or a tensor
            cannot be converted to a numpy array. The message names the
            offending parameter, and the underlying error is chained as
            ``__cause__``.
    """
    for name, expected in orig_pt_state_dict.items():
        try:
            actual = pt_state_dict_from_tf[name]
            # detach()/cpu() allow tensors that require grad or live on an
            # accelerator to be converted; a plain .numpy() rejects both.
            np.testing.assert_allclose(
                expected.detach().cpu().numpy(),
                actual.detach().cpu().numpy(),
                rtol=rtol,
                atol=atol,
            )
        except Exception as err:
            # Catch Exception rather than using a bare except so that
            # KeyboardInterrupt/SystemExit still propagate. Chain the cause
            # so the real failure (KeyError, AssertionError, ...) stays visible.
            raise ValueError(
                "There are problems in the parameter population process. "
                f"Cannot proceed :( (offending parameter: {name!r})"
            ) from err
|