"""Alternative PRF schemes: hash functions that map a window of previous token ids to an integer seed."""
import torch
from itertools import combinations
from functools import cache

# Key properties of a hashing scheme
HASHING_PROPERTIES = {
    "prf_type": str,  # Name of the PRF mapping tokens to a seed
    "context_width": int,  # Number of previous tokens considered
    "self_salt": bool,  # Use the token itself to seed and possibly reject it
    "hash_key": int,  # Large prime number for seed generation
}
def lookup_seeding_scheme(seeding_scheme: str):
    """Resolve a scheme name (or freeform descriptor) into its hashing properties."""
    if not isinstance(seeding_scheme, str):
        raise ValueError("Seeding scheme should be a string summarizing the procedure.")
    if seeding_scheme in ["simple_1", "lefthash"]:
        # Simple bigram hash; equivalent to "ff-additive_prf-1-False-15485863"
        prf_type = "additive_prf"
        context_width = 1
        self_salt = False
        hash_key = 15485863
    elif seeding_scheme in ["algorithm-3", "selfhash"]:
        prf_type = "anchored_minhash_prf"
        context_width = 4
        self_salt = True
        hash_key = 15485863
    elif seeding_scheme == "minhash":
        prf_type = "minhash_prf"
        context_width = 4
        self_salt = False
        hash_key = 15485863
    elif seeding_scheme == "skipgram":
        prf_type = "skipgram_prf"
        context_width = 5
        self_salt = False
        hash_key = 15485863
    elif seeding_scheme.startswith("ff"):
        # Freeform API: "ff-<prf_type>-<context_width>-<self_salt>[-<hash_key>]"
        split_scheme = seeding_scheme.split("-")
        prf_type = str(split_scheme[1])
        context_width = int(split_scheme[2])
        self_salt = split_scheme[3] == "True"
        hash_key = int(split_scheme[4]) if len(split_scheme) == 5 else 15485863
    else:
        raise ValueError(f"Invalid seeding scheme name {seeding_scheme} given. Try 'simple_1'?")
    assert prf_type in prf_lookup  # prf_lookup is defined below, at module scope
    return prf_type, context_width, self_salt, hash_key
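# Illustrative lookups (the values follow directly from the branches above):
#   lookup_seeding_scheme("selfhash")                -> ('anchored_minhash_prf', 4, True, 15485863)
#   lookup_seeding_scheme("ff-skipgram_prf-5-False") -> ('skipgram_prf', 5, False, 15485863)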
def prf_multiplicative(input_ids: torch.LongTensor, salt_key: int) -> int:
    return salt_key * input_ids.prod().item()

def prf_additive(input_ids: torch.LongTensor, salt_key: int) -> int:
    return salt_key * input_ids.sum().item()

def prf_minfunc(input_ids: torch.LongTensor, salt_key: int) -> int:
    return salt_key * input_ids.min().item()

def prf_simple_skip(input_ids: torch.LongTensor, salt_key: int, k: int = 2) -> int:
    return hash_int(salt_key * input_ids[::k]).prod().item()

def prf_skipgram(input_ids: torch.LongTensor, salt_key: int) -> int:
    return hash_int(salt_key * input_ids[0]).item()

def prf_anchored_skipgram(input_ids: torch.LongTensor, salt_key: int, anchor: int = -1) -> int:
    return (hash_int(salt_key * input_ids[0]) * hash_int(salt_key * input_ids[anchor])).item()

def prf_minhash(input_ids: torch.LongTensor, salt_key: int) -> int:
    return hash_int(salt_key * input_ids).min().item()

def prf_anchored_minhash(input_ids: torch.LongTensor, salt_key: int, anchor: int = -1) -> int:
    return (salt_key * hash_int(input_ids) * hash_int(input_ids[anchor])).min().item()

def prf_minskipgram(input_ids: torch.LongTensor, salt_key: int, k: int = 2) -> int:
    # Minimum over the products of all k-combinations of hashed context tokens
    skipgrams = torch.as_tensor(list(combinations(hash_int(salt_key * input_ids), k)))
    return skipgrams.prod(dim=1).min().item()

def prf_noncommutative(input_ids: torch.LongTensor, salt_key: int) -> int:
    # Order-sensitive: folds the tokens into the key one at a time, modulo 2**32
    key = torch.as_tensor(salt_key, dtype=torch.long)
    for entry in input_ids:
        key *= hash_int(key * entry)
        key %= 2**32
    return key.item()

def prf_positional(input_ids: torch.LongTensor, salt_key: int) -> int:
    # Weights each token by its (1-indexed) position in the context window
    return (salt_key * input_ids * torch.arange(1, len(input_ids) + 1, device=input_ids.device)).sum().item()
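# Worked example (from the definitions above): prf_additive(torch.tensor([3, 5]), salt_key=7)
# returns 7 * (3 + 5) == 56, while prf_multiplicative returns 7 * (3 * 5) == 105.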
prf_lookup = {
    "multiplicative_prf": prf_multiplicative,
    "additive_prf": prf_additive,
    "minfunc_prf": prf_minfunc,
    "simple_skip_prf": prf_simple_skip,
    "skipgram_prf": prf_skipgram,
    "anchored_skipgram_prf": prf_anchored_skipgram,
    "minhash_prf": prf_minhash,
    "anchored_minhash_prf": prf_anchored_minhash,
    "minskipgram_prf": prf_minskipgram,
    "noncomm_prf": prf_noncommutative,
    "position_prf": prf_positional,
}
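# A minimal sketch of a consumer (an illustrative assumption, not part of this module's
# interface): seed a torch.Generator from the last `context_width` tokens via the chosen
# scheme's PRF. The name `example_seed_rng` and the 2**64 - 1 modulus are hypothetical.
def example_seed_rng(input_ids: torch.LongTensor, seeding_scheme: str = "selfhash") -> torch.Generator:
    prf_type, context_width, self_salt, hash_key = lookup_seeding_scheme(seeding_scheme)
    if len(input_ids) < context_width:
        raise ValueError(f"Need at least {context_width} context tokens to seed the PRF.")
    seed = prf_lookup[prf_type](input_ids[-context_width:], salt_key=hash_key)
    generator = torch.Generator(device=torch.device("cpu"))
    generator.manual_seed(seed % (2**64 - 1))  # keep the seed inside torch's accepted range
    return generator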
# Generate a global permutation table once at startup
random_generator = torch.Generator(device=torch.device("cpu"))
random_generator.manual_seed(2971215073)  # the 47th Fibonacci number, which is prime
table_size = 1_000_003  # also prime
global_permutation_table = torch.randperm(table_size, device=torch.device("cpu"), generator=random_generator)

def hash_int(integer_tensor: torch.LongTensor) -> torch.LongTensor:
    # Permutation-table lookup modulo table_size; the +1 shifts outputs into [1, table_size]
    # so that products of hash values in the PRFs above can never collapse to zero.
    return global_permutation_table[integer_tensor.cpu() % table_size] + 1  # always returns CPU values
def avalanche_hash_tensor(integer_tensor: torch.LongTensor) -> torch.LongTensor:
    # Jenkins-style integer avalanche hash, applied elementwise in int32. Note that
    # torch's >> is an arithmetic (sign-extending) shift on signed dtypes, so this
    # deviates from the unsigned reference construction for negative intermediates.
    i = integer_tensor.to(torch.int32).clone()
    i -= i << 6
    i ^= i >> 17
    i -= i << 9
    i ^= i << 4
    i -= i << 3
    i ^= i << 10
    i ^= i >> 15
    return i.to(torch.long)
@cache
def avalanche_hash_int(integer: int) -> int:
    # The same mixing sequence in pure Python, memoized via @cache. Unlike the tensor
    # variant, the intermediates are unbounded Python ints rather than wrapping int32
    # values, so the two functions do not agree in general.
    i = integer % (2**32)
    i -= i << 6
    i ^= i >> 17
    i -= i << 9
    i ^= i << 4
    i -= i << 3
    i ^= i << 10
    i ^= i >> 15
    return i
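if __name__ == "__main__":
    # Minimal smoke test (an illustrative addition, not part of the original interface):
    # check that the hashes are deterministic and run one scheme end to end.
    tokens = torch.tensor([101, 2023, 2003, 1037], dtype=torch.long)
    assert torch.equal(hash_int(tokens), hash_int(tokens))
    assert hash_int(tokens).min().item() >= 1  # the +1 keeps lookups nonzero
    assert torch.equal(avalanche_hash_tensor(tokens), avalanche_hash_tensor(tokens))
    assert avalanche_hash_int(2023) == avalanche_hash_int(2023)  # second call hits the @cache
    prf_type, context_width, self_salt, hash_key = lookup_seeding_scheme("minhash")
    seed = prf_lookup[prf_type](tokens[-context_width:], salt_key=hash_key)
    print(f"{prf_type} over the last {context_width} tokens -> seed {seed}")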