"""Implement other PRF functions (These all vary only how they generate a single hash from the tokens in the context).
Can be hooked into existing WatermarkLogitsProcessor as modified base class WatermarkBase, see implementation in
extended_watermark_processor.py
"""
# coding=utf-8
# Copyright 2023 Authors of "A Watermark for Large Language Models"
# available at https://arxiv.org/abs/2301.10226
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from itertools import combinations
from functools import cache
# Key properties of a hashing scheme
props = {
"prf_type": str, # string name of the underlying PRF mapping multiple token ids to a random seed
"context_width": int, # this is h in the paper, how many previous tokens should be considered for each PRF
"self_salt": bool, # Use the rules laid in robust-watermarking to use the token itself to seed and possibly reject its own list
"hash_key": int, # integer, large prime, used to move seed away from low-entrop bit sequences in PRF chosen above
}
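# Illustrative example: the canonical "selfhash" / "algorithm-3" scheme defined below is
# equivalent to the freeform string "ff-anchored_minhash_prf-4-True-15485863", i.e.
# prf_type="anchored_minhash_prf", context_width=4, self_salt=True, hash_key=15485863.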
def seeding_scheme_lookup(seeding_scheme: str):
if not isinstance(seeding_scheme, str):
raise ValueError("Seeding scheme should be a string summarizing the procedure.")
if seeding_scheme == "simple_1" or seeding_scheme == "lefthash":
        # Default, simple bigram hash; alias for ff-additive_prf-1-False-15485863
prf_type = "additive_prf"
context_width = 1
self_salt = False
hash_key = 15485863
elif seeding_scheme == "algorithm-3" or seeding_scheme == "selfhash":
prf_type = "anchored_minhash_prf"
context_width = 4
self_salt = True
hash_key = 15485863
elif seeding_scheme == "minhash":
prf_type = "minhash_prf"
context_width = 4
self_salt = False
hash_key = 15485863
elif seeding_scheme == "skipgram":
prf_type = "skipgram_prf"
context_width = 5
self_salt = False
hash_key = 15485863
elif seeding_scheme.startswith("ff"): # freeform seeding scheme API - only use for experimenting
        # expects strings of the form ff-<prf_type>-<context_width>-<self_salt>-<hash_key>, e.g. ff-additive_prf-4-True-15485863 or ff-additive_prf-5-True (the hash key is optional)
split_scheme = seeding_scheme.split("-")
prf_type = str(split_scheme[1])
context_width = int(split_scheme[2])
self_salt = split_scheme[3] == "True"
if len(split_scheme) == 5:
hash_key = int(split_scheme[4])
else:
hash_key = 15485863
else:
raise ValueError(f"Invalid seeding scheme name {seeding_scheme} given. Try 'simple_1'?")
assert prf_type in prf_lookup.keys()
return prf_type, context_width, self_salt, hash_key
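# A minimal sketch (an assumption, not the upstream API verbatim) of how a WatermarkBase-style
# processor is expected to consume this lookup, per the module docstring and
# extended_watermark_processor.py:
#
#   self.prf_type, self.context_width, self.self_salt, self.hash_key = seeding_scheme_lookup(seeding_scheme)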
def multiplicative_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
return salt_key * input_ids.prod().item()
def additive_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
return salt_key * input_ids.sum().item()
def minfunc_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
# not a great idea for non-random input ids as in text
return salt_key * input_ids.min().item()
def simple_skip_prf(input_ids: torch.LongTensor, salt_key: int, k=2) -> int:
# k is the skip distance
return hashint(salt_key * input_ids[::k]).prod().item()
def skipgram_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
# maximum distance skipgram within context
return hashint(salt_key * input_ids[0]).item()
def anchored_skipgram_prf(input_ids: torch.LongTensor, salt_key: int, anchor: int = -1) -> int:
# maximum distance skipgram within context
return (hashint(salt_key * input_ids[0]) * hashint(salt_key * input_ids[anchor])).item()
def minhash_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
    # still not ideal for non-random input ids as in text, but less problematic than minfunc_prf
return hashint(salt_key * input_ids).min().item()
def anchored_minhash_prf(input_ids: torch.LongTensor, salt_key: int, anchor: int = -1) -> int:
# Anchor to one key to produce a min over pairs again
return (salt_key * hashint(input_ids) * hashint(input_ids[anchor])).min().item()
def minskipgram_prf(input_ids: torch.LongTensor, salt_key: int, k: int = 2) -> int:
# min over all skipgrams in context, k=2 is all pairs
skipgrams = torch.as_tensor(list(combinations(hashint(salt_key * input_ids), 2)))
return skipgrams.prod(dim=1).min().item()
def noncomm_prf(input_ids: torch.LongTensor, salt_key: int, k: int = 2) -> int:
key = torch.as_tensor(salt_key, dtype=torch.long)
for entry in input_ids:
key *= hashint(key * entry)
key %= 2**32
return key.item()
def position_prf(input_ids: torch.LongTensor, salt_key: int, k: int = 2) -> int:
return (salt_key * input_ids * torch.arange(1, len(input_ids) + 1, device=input_ids.device)).sum().item()
prf_lookup = {
"multiplicative_prf": multiplicative_prf,
"additive_prf": additive_prf,
"minfunc_prf": minfunc_prf,
"simple_skip_prf": simple_skip_prf,
"skipgram_prf": skipgram_prf,
"anchored_skipgram_prf": anchored_skipgram_prf,
"minhash_prf": minhash_prf,
"anchored_minhash_prf": anchored_minhash_prf,
"minskipgram_prf": minskipgram_prf,
"noncomm_prf": noncomm_prf,
"position_prf": position_prf,
}
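# Dispatch sketch: downstream code is assumed to index this table with the prf_type returned by
# seeding_scheme_lookup and call the PRF on the last context_width tokens, e.g.
#
#   seed = prf_lookup[prf_type](input_ids[-context_width:], salt_key=hash_key)
#
# (see the runnable demo at the bottom of this file).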
# Generate a global permute table once at startup
rng = torch.Generator(device=torch.device("cpu"))
rng.manual_seed(2971215073) # fib47 is prime
table_size = 1_000_003
fixed_table = torch.randperm(table_size, device=torch.device("cpu"), generator=rng)  # actually faster than I thought
def hashint(integer_tensor: torch.LongTensor) -> torch.LongTensor:
"""Sane version, in the end we only need a small permutation table."""
    return fixed_table[integer_tensor.cpu() % table_size] + 1  # minor cheat here, this function always returns CPU values
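# Illustrative example: hashint(torch.tensor([3, 1, 4])) yields a CPU LongTensor of the same
# shape with values in [1, table_size], since fixed_table is a permutation of 0..table_size-1
# and 1 is added to every entry.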
def _hashint_avalanche_tensor(integer_tensor: torch.LongTensor):
"""http://burtleburtle.net/bob/hash/integer.html, ported into pytorch, runs on tensors. Apparently a decent avalanche."""
i = integer_tensor.to(torch.int32).clone() # or torch.int16?
i -= i << 6
i ^= i >> 17
i -= i << 9
i ^= i << 4
i -= i << 3
i ^= i << 10
i ^= i >> 15
return i.to(torch.long)
@cache
def _hashint_avalanche_int(integer: int):
"""http://burtleburtle.net/bob/hash/integer.html, runs in base python, caches based on access.
Does this make sense for signed 64bit ints?"""
i = integer % (2**32)
i -= i << 6
i ^= i >> 17
i -= i << 9
i ^= i << 4
i -= i << 3
i ^= i << 10
i ^= i >> 15
return i
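# A runnable, self-contained sketch of how the pieces above fit together. The token ids below
# are made up for demonstration, and seeding a torch.Generator from the PRF output
# (modulo 2**64 - 1 to stay in range) is an assumption about how downstream code uses the seed;
# the real wiring lives in extended_watermark_processor.py.
if __name__ == "__main__":
    prf_type, context_width, self_salt, hash_key = seeding_scheme_lookup("lefthash")
    context = torch.tensor([464, 2068, 7586, 21831, 18045], dtype=torch.long)  # made-up token ids
    seed = prf_lookup[prf_type](context[-context_width:], salt_key=hash_key)
    print(f"prf_type={prf_type}, context_width={context_width}, self_salt={self_salt}, seed={seed}")
    demo_rng = torch.Generator(device="cpu")
    demo_rng.manual_seed(seed % (2**64 - 1))  # clamp into the range accepted by manual_seed
    greenlist_perm = torch.randperm(8, generator=demo_rng)  # e.g. permute a toy vocabulary of 8 tokens
    print(greenlist_perm)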