|
""" |
|
Unit tests for the monkeypatch utils
|
""" |
|
import unittest |
|
|
|
import torch |
|
|
|
from axolotl.monkeypatch.utils import get_cu_seqlens, get_cu_seqlens_from_pos_ids |
|
|
|
|
|
class TestMonkeyPatchUtils(unittest.TestCase):
    """
    Test cases for the cu_seqlens helpers imported from axolotl.monkeypatch.utils
    """

    def test_get_cu_seqlens_1d(self):
        """Check get_cu_seqlens on one row of run-labelled mask values."""
        # Runs of 1s, 2s, 3s and 4s have lengths 4, 3, 5 and 2, followed by
        # two trailing zeros.
        packed_mask = torch.tensor(
            [[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]]
        )
        # Cumulative end offsets of those runs; the final entry (16) is the
        # full row length, so the trailing zeros count as one last segment.
        expected = torch.tensor([0, 4, 7, 12, 14, 16], dtype=torch.int32)

        # Only the first element of the helper's return value is checked.
        actual = get_cu_seqlens(packed_mask)[0]
        self.assertTrue(torch.allclose(actual, expected))

    def test_get_cu_seqlens_from_pos_ids_1d(self):
        """Check get_cu_seqlens_from_pos_ids on one row of restarting position ids."""
        # Positions count up from 0 within each run (lengths 4, 3, 5, 2) and
        # the row ends with two zeros.
        pos_ids = torch.tensor(
            [[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0]]
        )
        # Same boundaries as the mask-based test: there is no boundary at 15,
        # i.e. the two trailing zeros are expected to form a single segment.
        expected = torch.tensor([0, 4, 7, 12, 14, 16], dtype=torch.int32)

        # Only the first element of the helper's return value is checked.
        actual = get_cu_seqlens_from_pos_ids(pos_ids)[0]
        self.assertTrue(torch.allclose(actual, expected))
|
|
|
|
|
# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|
|