# -*- coding: utf-8 -*-
import torch
import numpy as np
from typing import List
def get_rep_pos(tokenized: torch.Tensor, rep_tokens: list):
    # Collect, for each token to be repeated, the positions where it occurs
    # in the 1-D tokenized sequence (one position array per repeated token,
    # matching the List[np.ndarray] expected by shift_tensor_dim0).
    pos_list = []
    for token in rep_tokens:
        pos_list.append(torch.where(tokenized == token)[0].numpy())
    return pos_list
def shift_tensor_dim0(ori: torch.Tensor, r_pos: List[np.ndarray], reps: int):
    # Duplicate the entries of `ori` at the positions in `r_pos` so that each
    # appears `reps` times in a row, shifting the following entries to the
    # right along dim 0 (in place). Entries pushed past the end are dropped.
    assert reps >= 1
    device = ori.device
    d = ori.shape[0]
    # offset[i] = how far position i must move to make room for the copies.
    offset = np.zeros(d, dtype=np.int64)
    r_pos_cat = np.concatenate(r_pos)
    for p in r_pos_cat:
        offset[p + 1:] += (reps - 1)
    r_cnt = r_pos_cat.shape[0]
    # New position of every surviving entry after the shift.
    target_pos = (np.arange(d) + offset)[:d - r_cnt * (reps - 1)]
    ori[target_pos] = ori[np.arange(target_pos.shape[0])]
    # Final positions of the repeated entries; fill each run with copies.
    rep_final_pos: np.ndarray = target_pos[r_pos_cat].repeat(reps) + np.tile(np.arange(reps), r_cnt)
    ori[rep_final_pos] = ori[target_pos[r_pos_cat].repeat(reps)]
    # Group the final positions back per input token: shape (occurrences, reps).
    rep_final_pos_list = []
    lo = 0
    for i in range(len(r_pos)):
        r_one_times = r_pos[i].shape[0]
        r_one_nums = r_one_times * reps
        rep_final_pos_list.append(rep_final_pos[lo: lo + r_one_nums].reshape(r_one_times, reps))
        lo += r_one_nums
    return ori, rep_final_pos_list
def _test_get_rep_pos():
    tokenized = torch.LongTensor([0, 1, 2, 2, 3, 4, 5, 6, 7, 99] + [99] * 20)
    print('[from]:', tokenized)
    rep_tokens = [2, 6]
    rep_times = 2
    rep_pos = get_rep_pos(tokenized, rep_tokens)
    print('[rep_pos]:', rep_pos)
    res, rep_pos_final = shift_tensor_dim0(tokenized, rep_pos, rep_times)
    print('[to]:', res)
    print('[final pos]:', rep_pos_final)
def _test_shift_tensor_dim0():
    embedded = torch.arange(20)
    print(embedded)
    pos = [np.array([3, 6, 8])]  # one group of positions to repeat
    times = 1
    output = shift_tensor_dim0(embedded, pos, times)
    print(output)
if __name__ == "__main__":
    _test_get_rep_pos()