File size: 3,704 Bytes
2bb0b78 6910e6a 2bb0b78 00568c1 6910e6a 2bb0b78 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
"""
Unit tests for the cu_seqlens, max-seqlen and unpadding helpers in axolotl.monkeypatch.utils
"""
import unittest
import torch
from axolotl.monkeypatch.utils import (
get_cu_seqlens,
get_cu_seqlens_from_pos_ids,
get_max_seqlen_in_batch,
get_unpad_data,
)
class TestMonkeyPatchUtils(unittest.TestCase):
    """
    Unit test class for monkeypatch utils
    """

    def _assert_close(self, actual, expected):
        """
        Fail unless ``torch.allclose(actual, expected)`` holds.

        Uses the same pass/fail criterion as
        ``self.assertTrue(torch.allclose(...))``, including ``allclose``'s
        broadcasting of the two shapes. On failure, the message shows both
        tensors with their shapes and dtypes instead of the bare
        "False is not true".
        """
        if not torch.allclose(actual, expected):
            self.fail(
                "tensors differ\n"
                f"actual   {tuple(actual.shape)} {actual.dtype}: {actual}\n"
                f"expected {tuple(expected.shape)} {expected.dtype}: {expected}"
            )

    def test_get_cu_seqlens_1d(self):
        # Segment ids 1..4 span 4, 3, 5 and 2 tokens. The two trailing zeros
        # are padding, which the expected offsets count as a final span
        # (14 -> 16).
        attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
        target_res = torch.tensor([0, 4, 7, 12, 14, 16], dtype=torch.int32)
        # NOTE(review): [0] presumably selects the cu_seqlens tensor from a
        # returned tuple -- confirm against axolotl.monkeypatch.utils.
        self._assert_close(get_cu_seqlens(attn_mask)[0], target_res)

    def test_get_cu_seqlens_from_pos_ids_1d(self):
        # A new sequence starts wherever position_ids restarts at 0. The
        # trailing zeros (padding) form the final span 14 -> 16.
        position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0]])
        target_res = torch.tensor([0, 4, 7, 12, 14, 16], dtype=torch.int32)
        self._assert_close(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)

    def test_get_cu_seqlens_from_pos_ids_2d(self):
        position_ids = torch.tensor(
            [
                [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0],
                [0, 1, 2, 3, 4, 0, 1, 2, 0, 1, 2, 3, 4, 5, 6, 0],
            ]
        )
        # Row 1 has one span fewer than row 0. Its expected offsets repeat the
        # final value (16) so both rows have the same length.
        target_res = torch.tensor(
            [[0, 4, 7, 12, 14, 16], [0, 5, 8, 15, 16, 16]], dtype=torch.int32
        )
        self._assert_close(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)

    def test_get_max_seqlen_in_batch(self):
        # Despite the helper's name, the expected result is the length of
        # every non-padding segment (ids 1..4 -> 4, 3, 5, 2), not only the max.
        attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
        target_res = torch.tensor([4, 3, 5, 2], dtype=torch.int32)
        self._assert_close(get_max_seqlen_in_batch(attn_mask), target_res)

    def test_get_unpad_data(self):
        # Each case has five fields:
        #   - a label;
        #   - the attention mask;
        #   - the expected indices of the non-padding tokens;
        #   - the expected cumulative segment offsets;
        #   - the expected longest segment.
        # Unlike the get_cu_seqlens targets, padding is NOT a span here.
        cases = [
            (
                "single_row",
                torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]]),
                # Positions 14 and 15 are padding and are excluded.
                torch.arange(14),
                torch.tensor([0, 4, 7, 12, 14], dtype=torch.int32),
                5,
            ),
            (
                "two_rows",
                torch.tensor(
                    [
                        [1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0],
                        [1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 4, 4, 4, 5, 5, 5],
                    ]
                ),
                # Expected indices count positions across both rows. Row 0
                # keeps 0..13 and skips its padding (14, 15). Row 1, which
                # starts at 16, keeps 16..31.
                torch.cat([torch.arange(14), torch.arange(16, 32)]),
                torch.tensor(
                    [0, 4, 7, 12, 14, 17, 22, 24, 27, 30], dtype=torch.int32
                ),
                5,
            ),
        ]
        for label, attn_mask, target_indices, target_cu_seqlen, target_max in cases:
            with self.subTest(label):
                indices, cu_seqlen, max_seqlen_in_batch = get_unpad_data(attn_mask)
                self._assert_close(indices, target_indices)
                self._assert_close(cu_seqlen, target_cu_seqlen)
                self.assertEqual(target_max, max_seqlen_in_batch)
if __name__ == "__main__":
    # Only runs when this file is executed directly, not when a test runner
    # imports it. Loads and runs the TestCase classes defined in this module.
    unittest.main(module="__main__")
|