Update modeling_hymba.py
Browse files- modeling_hymba.py +0 -70
modeling_hymba.py
CHANGED
@@ -1714,76 +1714,6 @@ class HymbaBlock(nn.Module):
|
|
1714 |
|
1715 |
if ssm_state is not None and cache_params is not None:
|
1716 |
cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
|
1717 |
-
# if use_precomputed_states and self.layer_idx==31:
|
1718 |
-
# # except Exception as e:
|
1719 |
-
# print("\n\n\n\n")
|
1720 |
-
# # print(e)
|
1721 |
-
# print(f"use_precomputed_states {use_precomputed_states}; {index} {self.D}, {time_proj_bias} ")
|
1722 |
-
# print(f"{self.D[index]} ")
|
1723 |
-
# # cache_params.ssm_states[self.layer_idx],
|
1724 |
-
# # hidden_states[..., 0],
|
1725 |
-
# # discrete_time_step[..., 0],
|
1726 |
-
# # A,
|
1727 |
-
# # B[:, 0],
|
1728 |
-
# # C[:, 0],
|
1729 |
-
# # self.D[index],
|
1730 |
-
# # gate[..., 0],
|
1731 |
-
# # time_proj_bias,
|
1732 |
-
# print("=== Variable Values ===")
|
1733 |
-
# try:
|
1734 |
-
# print(f"cache_params.ssm_states[{self.layer_idx}]: {cache_params.ssm_states[self.layer_idx]}")
|
1735 |
-
# print(f"{cache_params.ssm_states[self.layer_idx].shape}")
|
1736 |
-
# except Exception as e:
|
1737 |
-
# print(f"Error accessing cache_params.ssm_states[{self.layer_idx}]: {e}")
|
1738 |
-
|
1739 |
-
# try:
|
1740 |
-
# print(f"hidden_states[..., 0]: {hidden_states[..., 0]}")
|
1741 |
-
# print(f"hidden_states[..., 0] shape: {hidden_states[..., 0].shape}")
|
1742 |
-
# except Exception as e:
|
1743 |
-
# print(f"Error accessing hidden_states[..., 0]: {e}")
|
1744 |
-
|
1745 |
-
# try:
|
1746 |
-
# print(f"discrete_time_step[..., 0]: {discrete_time_step[..., 0]}")
|
1747 |
-
# print(f"discrete_time_step[..., 0].shape: {discrete_time_step[..., 0].shape}")
|
1748 |
-
# except Exception as e:
|
1749 |
-
# print(f"Error accessing discrete_time_step[..., 0]: {e}")
|
1750 |
-
|
1751 |
-
# try:
|
1752 |
-
# print(f"A: {A}")
|
1753 |
-
# print(f"A.shape: {A.shape}")
|
1754 |
-
# except Exception as e:
|
1755 |
-
# print(f"Error accessing A: {e}")
|
1756 |
-
|
1757 |
-
# try:
|
1758 |
-
# print(f"B[:, 0]: {B[:, 0].shape}")
|
1759 |
-
# print(f"B[:, 0].shape: {B[:, 0].shape}")
|
1760 |
-
# except Exception as e:
|
1761 |
-
# print(f"Error accessing B[:, 0]: {e}")
|
1762 |
-
|
1763 |
-
# try:
|
1764 |
-
# print(f"C[:, 0]: {C[:, 0]}")
|
1765 |
-
# print(f"C[:, 0].shape: {C[:, 0].shape}")
|
1766 |
-
# except Exception as e:
|
1767 |
-
# print(f"Error accessing C[:, 0]: {e}")
|
1768 |
-
|
1769 |
-
# try:
|
1770 |
-
# print(f"D[index]: {self.D[index]}")
|
1771 |
-
# print(f"D[index].shape: {self.D[index].shape}")
|
1772 |
-
# except Exception as e:
|
1773 |
-
# print(f"Error accessing D[{index}]: {e}")
|
1774 |
-
|
1775 |
-
# try:
|
1776 |
-
# print(f"gate[..., 0]: {gate[..., 0]}")
|
1777 |
-
# print(f"gate[..., 0].shape: {gate[..., 0].shape}")
|
1778 |
-
# except Exception as e:
|
1779 |
-
# print(f"Error accessing gate[..., 0]: {e}")
|
1780 |
-
|
1781 |
-
# try:
|
1782 |
-
# print(f"time_proj_bias: {time_proj_bias}")
|
1783 |
-
# except Exception as e:
|
1784 |
-
# print(f"Error accessing time_proj_bias: {e}")
|
1785 |
-
|
1786 |
-
# print("\n\n\n\n")
|
1787 |
|
1788 |
scan_outputs = scan_outputs.transpose(1, 2)
|
1789 |
|
|
|
1714 |
|
1715 |
if ssm_state is not None and cache_params is not None:
|
1716 |
cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1717 |
|
1718 |
scan_outputs = scan_outputs.transpose(1, 2)
|
1719 |
|