|
""" |
|
Unit tests for the monkeypatch utils |
|
""" |
|
import unittest |
|
|
|
import torch |
|
|
|
from axolotl.monkeypatch.utils import ( |
|
get_cu_seqlens, |
|
get_cu_seqlens_from_pos_ids, |
|
get_max_seqlen_in_batch, |
|
get_unpad_data, |
|
) |
|
|
|
|
|
class TestMonkeyPatchUtils(unittest.TestCase):
    """
    Checks the sequence-length helpers imported from
    ``axolotl.monkeypatch.utils`` against small hand-written tensors.
    """

    def _require_close(self, first, second):
        """Assert that ``torch.allclose(first, second)`` is true."""
        self.assertTrue(torch.allclose(first, second))

    def test_get_cu_seqlens_1d(self):
        """Single mask row: runs of length 4, 3, 5, 2, then two zeros."""
        mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
        expected = torch.tensor([0, 4, 7, 12, 14, 16], dtype=torch.int32)

        result = get_cu_seqlens(mask)

        # Only the first element of the returned value is compared.
        self._require_close(result[0], expected)

    def test_get_cu_seqlens_from_pos_ids_1d(self):
        """Single row of position ids that restart at 0 for every run."""
        pos_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0]])
        expected = torch.tensor([0, 4, 7, 12, 14, 16], dtype=torch.int32)

        result = get_cu_seqlens_from_pos_ids(pos_ids)

        # Only the first element of the returned value is compared.
        self._require_close(result[0], expected)

    def test_get_cu_seqlens_from_pos_ids_2d(self):
        """Two rows of position ids; one row of expected boundaries per input row."""
        first_row = [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0]
        second_row = [0, 1, 2, 3, 4, 0, 1, 2, 0, 1, 2, 3, 4, 5, 6, 0]
        pos_ids = torch.tensor([first_row, second_row])
        # The second expected row repeats its final value (16, 16).
        expected = torch.tensor(
            [
                [0, 4, 7, 12, 14, 16],
                [0, 5, 8, 15, 16, 16],
            ],
            dtype=torch.int32,
        )

        result = get_cu_seqlens_from_pos_ids(pos_ids)

        self._require_close(result[0], expected)

    def test_get_max_seqlen_in_batch(self):
        """The expected tensor lists the length of each non-zero run in the mask."""
        mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
        expected = torch.tensor([4, 3, 5, 2], dtype=torch.int32)

        result = get_max_seqlen_in_batch(mask)

        self._require_close(result, expected)

    def test_get_unpad_data(self):
        """Checks the three values returned by get_unpad_data for a 1-row and a 2-row mask."""
        # Scenario 1: one row whose last two entries are zero.
        single_mask = torch.tensor(
            [[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]]
        )
        # Positions 14 and 15 (the zero entries) are absent from the expected indices.
        single_indices = torch.tensor([*range(14)])
        single_cu_seqlen = torch.tensor([0, 4, 7, 12, 14], dtype=torch.int32)
        single_max_seqlen = 5

        indices, cu_seqlen, max_seqlen_in_batch = get_unpad_data(single_mask)

        self._require_close(single_indices, indices)
        self._require_close(single_cu_seqlen, cu_seqlen)
        self.assertEqual(single_max_seqlen, max_seqlen_in_batch)

        # Scenario 2: two rows; the second row contains no zeros.
        paired_mask = torch.tensor(
            [
                [1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0],
                [1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 4, 4, 4, 5, 5, 5],
            ]
        )
        # Expected indices are 0..13 followed by 16..31: 14 and 15 are skipped.
        paired_indices = torch.tensor([*range(14), *range(16, 32)])
        paired_cu_seqlen = torch.tensor(
            [0, 4, 7, 12, 14, 17, 22, 24, 27, 30], dtype=torch.int32
        )
        paired_max_seqlen = 5

        indices, cu_seqlen, max_seqlen_in_batch = get_unpad_data(paired_mask)

        self._require_close(paired_indices, indices)
        self._require_close(paired_cu_seqlen, cu_seqlen)
        self.assertEqual(paired_max_seqlen, max_seqlen_in_batch)
|
|
|
|
|
# Run the unittest command-line runner when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|
|