# -*- coding: utf-8 -*-
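"""Utilities for duplicating selected token positions in place.

`get_rep_pos` finds every occurrence of a set of repeat tokens in a 1-D
token tensor; `shift_tensor_dim0` then expands each occurrence `reps`
times along dim 0, shifting the tail of the tensor into its trailing
padding slots.
"""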
import torch
import numpy as np
from typing import List


def get_rep_pos(tokenized: torch.Tensor, rep_tokens: list) -> List[np.ndarray]:
    """Collect, for each repeat token, the positions where it occurs."""
    pos_list = []
    for token in rep_tokens:
        # torch.where on a 1-D tensor returns a one-element tuple of indices.
        pos_list.append(torch.where(tokenized == token)[0].cpu().numpy())
    return pos_list
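# e.g. get_rep_pos(torch.LongTensor([5, 7, 5]), [5]) -> [array([0, 2])]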


def shift_tensor_dim0(ori: torch.Tensor, r_pos: List[np.ndarray], reps: int):
    """Duplicate the entries at the positions in r_pos `reps` times in place.

    Entries behind each duplicated position are shifted right along dim 0,
    so `ori` must end with at least r_cnt * (reps - 1) padding slots to
    absorb the shift, where r_cnt is the total number of positions.
    """
    assert reps >= 1
    d = ori.shape[0]
    # Every repeated position pushes all later entries (reps - 1) slots back.
    offset = np.zeros(d, dtype=np.int64)
    r_pos_cat = np.concatenate(r_pos)
    for p in r_pos_cat:
        offset[p + 1:] += (reps - 1)

    # Scatter the surviving prefix to its shifted positions. The fancy-indexed
    # right-hand side is materialized as a copy first, so writing in place is safe.
    r_cnt = r_pos_cat.shape[0]
    target_pos = (np.arange(d) + offset)[:d - r_cnt * (reps - 1)]
    ori[target_pos] = ori[np.arange(target_pos.shape[0])]

    # Fill each run of `reps` consecutive slots with its repeated entry.
    rep_final_pos: np.ndarray = target_pos[r_pos_cat].repeat(reps) + np.tile(np.arange(reps), r_cnt)
    ori[rep_final_pos] = ori[target_pos[r_pos_cat].repeat(reps)]

    # Regroup the final positions per repeat token, shaped (occurrences, reps).
    rep_final_pos_list = []
    lo = 0
    for i in range(len(r_pos)):
        r_one_times = r_pos[i].shape[0]
        r_one_nums = r_one_times * reps
        rep_final_pos_list.append(rep_final_pos[lo: lo + r_one_nums].reshape(r_one_times, reps))
        lo += r_one_nums
    return ori, rep_final_pos_list
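# Example: for ori = [a, b, c, pad], r_pos = [np.array([1])], reps = 2,
# ori becomes [a, b, b, c] and rep_final_pos_list == [array([[1, 2]])].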


def _test_get_rep_pos():
    tokenized = torch.LongTensor([0, 1, 2, 2, 3, 4, 5, 6, 7, 99] + [99] * 20)
    print('[from]:', tokenized)
    rep_tokens = [2, 6]
    rep_times = 2

    rep_pos = get_rep_pos(tokenized, rep_tokens)
    print('[rep_pos]:', rep_pos)
    res, rep_pos_final = shift_tensor_dim0(tokenized, rep_pos, rep_times)
    print('[to]:', res)
    print('[final pos]:', rep_pos_final)
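    # Should print [to] = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6, 7, 99, ...] with
    # [final pos] = [array([[2, 3], [4, 5]]), array([[9, 10]])].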


def _test_shift_tensor_dim0():
    embedded = torch.arange(20)
    print(embedded)
    # shift_tensor_dim0 expects a list of position arrays, one per repeat token.
    pos = [np.array([3, 6, 8])]
    times = 1
    output = shift_tensor_dim0(embedded, pos, times)
    print(output)
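    # With times == 1 nothing is duplicated: the tensor comes back unchanged
    # and the positions regroup as [array([[3], [6], [8]])].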


if __name__ == "__main__":
    _test_get_rep_pos()
    _test_shift_tensor_dim0()