# coding=utf-8
# Copyright 2023 Authors of "A Watermark for Large Language Models"
# available at https://arxiv.org/abs/2301.10226
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import numpy as np

from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer

from utils.generation import tokenize_and_truncate, collate_batch
from metrics.repetition_diversity import (
    measure_repetition_and_diversity,
    dummy_rep_div_result,
)
from metrics.p_sp import evaluate_p_sp
from metrics.detect_retrieval import detect_retrieval
from metrics.coherence import get_coherence_score
from metrics.mauve import get_mauve_score
from utils.hypothesis_testing import (
    chi_squared_runs_test,
    F_succ_T_runs_dummy_dict_w_bins,
    F_succ_T_runs_dummy_dict_no_bins,
    T_and_F_runs_dummy_dict_w_bins,
    T_and_F_runs_dummy_dict_no_bins,
)

from watermark_processor import WatermarkDetector

# These arguments are ignored when doing checks between the meta file and cmdline args
NO_CHECK_ARGS = [
    "evaluation_metrics",
    "verbose",
    "wandb",
    "wandb_entity",
    "input_dir",
    "output_dir",
    "run_name",
    "overwrite_output_file",
    "overwrite_args",
    "limit_rows",
    "concat_rows",
    "max_prefix_length",
]


def conditional_no_check_args(no_check_args, evaluation_metrics, args):
    if "ppl" not in evaluation_metrics:
        no_check_args.append("oracle_model_name_or_path")
        no_check_args.append("load_fp16")
        no_check_args.append("ppl_batch_size")
    return no_check_args
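
# Illustrative sketch (not part of the original pipeline): the no-check list is
# presumably consumed by the driver script roughly like the hypothetical helper
# below, which compares the args recorded in a generation run's meta file against
# the current cmdline args and only flags mismatches on keys that are checked.
#
#     def check_args_against_meta(meta_args: dict, cmdline_args: dict, no_check_args):
#         for key, meta_value in meta_args.items():
#             if key in no_check_args:
#                 continue  # allowed to differ between generation and evaluation
#             if cmdline_args.get(key) != meta_value:
#                 raise ValueError(f"Arg mismatch for '{key}': {meta_value} vs {cmdline_args.get(key)}")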

# Series of configuration variables for the evaluation script

# These are the metrics we support
SUPPORTED_METRICS = [
    "z-score",
    "windowed-z-score",
    "run-len-chisqrd",
    "ppl",
    "diversity",
    "repetition",
    "p-sp",
    "coherence",
    "mauve",
    "detect-retrieval",
    "detectgpt",
]

# These are the output text columns we want to compute metrics on
OUTPUT_TEXT_COLUMN_NAMES = [
    "baseline_completion",
    "no_wm_output",
    "w_wm_output",
    "w_wm_output_attacked",
]

# etc for other evaluation types
ZSCORE_TEXT_COLUMN_NAMES = OUTPUT_TEXT_COLUMN_NAMES
RUN_LEN_CHISQRD_TEXT_COLUMN_NAMES = OUTPUT_TEXT_COLUMN_NAMES
REPETITION_TEXT_COLUMN_NAMES = OUTPUT_TEXT_COLUMN_NAMES

# note the convention of including the input as 0th column
COHERENCE_TEXT_COLUMN_NAMES = ["truncated_input"] + OUTPUT_TEXT_COLUMN_NAMES

# These are the column pairs we want to compute p-sp for
OUTPUT_TEXT_PAIR_COLUMN_NAMES = [
    ["baseline_completion", "no_wm_output"],
    ["baseline_completion", "w_wm_output"],
    ["baseline_completion", "w_wm_output_attacked"],
    ["no_wm_output", "w_wm_output"],
    ["w_wm_output", "w_wm_output_attacked"],
]
P_SP_TEXT_PAIR_COLUMN_NAMES = OUTPUT_TEXT_PAIR_COLUMN_NAMES
MAUVE_TEXT_PAIR_COLUMN_NAMES = OUTPUT_TEXT_PAIR_COLUMN_NAMES

ROC_TEST_STAT_SUFFIXES = [
    "z_score",
    "win20-1_z_score",
    "win40-1_z_score",
    "winmax-1_z_score",
    "run_len_chisqrd_statistic",
    "retrieval_score",
    "detectgpt_score_100_z",
    "detectgpt_score_100_d",
]

FILTER_BY_COLUMNS = ["baseline_completion", "no_wm_output", "w_wm_output"]
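
# Note on naming: the per-metric functions below prefix each score key with the text
# column it was computed on, so full metric columns look like "<column>_<suffix>",
# e.g. "w_wm_output_z_score" or "baseline_completion_retrieval_score". The suffixes in
# ROC_TEST_STAT_SUFFIXES are therefore the detection statistics that a downstream ROC
# analysis would presumably sweep over, per output column (that analysis lives outside
# this file).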

def concat_rows(examples, tokenizer=None, args=None):
    # concat the rows (there will be k rows per example)
    # by simply joining the strings with a space
    for col_name in examples.keys():
        if col_name in OUTPUT_TEXT_COLUMN_NAMES:
            examples[col_name] = " ".join(examples[col_name])
        else:
            # # check that all other columns have len args.concat_rows
            # if len(examples[col_name]) != args.concat_rows:
            #     # append None to the col to make it the right length
            #     examples[col_name] = examples[col_name] + [None] * (
            #         args.concat_rows - len(examples[col_name])
            #     )
            # for now, just keep the first element of each non-text column
            # (admittedly a lossy convention)
            examples[col_name] = examples[col_name][0]

    # Now, update the lengths
    for col_name in OUTPUT_TEXT_COLUMN_NAMES:
        if col_name in examples:
            examples[f"{col_name}_length"] = len(
                tokenizer(examples[col_name], add_special_tokens=False)["input_ids"]
            )
    return examples

def load_tokenizer(args):
    model_name = args.model_name_or_path
    print(f"Loading tokenizer for: {model_name}")
    if "llama" in model_name:
        tokenizer = LlamaTokenizer.from_pretrained(model_name)
        tokenizer.pad_token_id = 0  # unk
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
    return tokenizer

def load_detector(args):
    if "llama" in args.model_name_or_path:
        tokenizer = LlamaTokenizer.from_pretrained(args.model_name_or_path)
        tokenizer.pad_token_id = 0  # unk
    else:
        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)

    device = "cuda" if (args.use_gpu and torch.cuda.is_available()) else "cpu"

    watermark_detector = WatermarkDetector(
        vocab=list(tokenizer.get_vocab().values()),
        gamma=args.gamma,
        seeding_scheme=args.seeding_scheme,
        device=device,
        tokenizer=tokenizer,
        z_threshold=args.detection_z_threshold,
        normalizers=args.normalizers,
        ignore_repeated_ngrams=args.ignore_repeated_ngrams,
    )
    return watermark_detector
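
# Illustrative usage sketch (assumes a HF `datasets.Dataset` driver script, which is
# not part of this file): the detector built above would typically be constructed once
# and reused across rows via a map over the dataset, e.g.
#
#     from functools import partial
#     detector = load_detector(args)
#     dataset = dataset.map(
#         partial(compute_z_scores, watermark_detector=detector, args=args)
#     )
#
# The gamma and seeding_scheme in `args` are assumed to match the ones used at
# generation time; otherwise the resulting z-scores are meaningless.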

def compute_z_score(
    example,
    text_column_name=None,
    watermark_detector=None,
    args=None,
    window_size=None,
    window_stride=None,
):
    # for now, don't request the green token mask
    # if we're using normalizers
    return_green_token_mask = args.return_green_token_mask
    if args.normalizers != []:
        return_green_token_mask = None

    input_text = example[text_column_name]
    error = False
    if input_text == "":
        error = True
    else:
        try:
            score_dict = watermark_detector.detect(
                input_text,
                window_size=window_size,
                window_stride=window_stride,
                return_green_token_mask=return_green_token_mask,
                return_prediction=False,  # this conversion to a "decision" is only desired in the demo context
                convert_to_float=True,  # this helps with integrity under NaNs
                return_z_at_T=args.compute_scores_at_T,
            )
        except Exception as e:
            print(e)
            error = True

    if error:
        problem_text = f"'{input_text[:40]} {'[...]' if len(input_text) > 40 else ''}'"
        if args.verbose:
            print(
                f"{(f'Windowed({window_size})' if window_size else '')} Detection error on text: {problem_text}"
            )
        # fall back to dummy values, e.g. when the string was too short to compute metrics
        score_dict = watermark_detector.dummy_detect(
            return_prediction=False,
            return_green_token_mask=return_green_token_mask,
            return_z_at_T=args.compute_scores_at_T,
        )

    # the current detect logic only reports "confidence" sometimes, which causes schema issues downstream
    score_dict.pop("confidence", None)

    # replace every key name in score dict with the text_column_name + key name
    # and then add them to the example dict
    score_dict = {
        text_column_name
        + (f"_win{window_size}-{window_stride}" if window_size else "")
        + "_"
        + k: v
        for k, v in score_dict.items()
    }
    example.update(score_dict)
    return example
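
# For reference, the naming scheme above yields keys such as "w_wm_output_z_score" for
# the unwindowed call, and "w_wm_output_win40-1_z_score" when called with
# window_size=40, window_stride=1. These particular columns are examples; the actual
# key set depends on which text columns and window settings are configured.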

def compute_z_scores(example, watermark_detector=None, args=None):
    # this just iterates the z-score function over the columns we want to compute z-scores for
    for col_name in ZSCORE_TEXT_COLUMN_NAMES:
        if col_name in example:
            example = compute_z_score(
                example, text_column_name=col_name, watermark_detector=watermark_detector, args=args
            )
    return example


def compute_windowed_z_scores(example, watermark_detector=None, args=None):
    # this iterates the z-score function over the columns we want to compute z-scores for
    for col_name in ZSCORE_TEXT_COLUMN_NAMES:
        if col_name in example:
            for window_size in args.window_settings:
                example = compute_z_score(
                    example,
                    text_column_name=col_name,
                    watermark_detector=watermark_detector,
                    args=args,
                    window_size=window_size,
                    window_stride=1,
                )
    return example

def compute_run_len_chisqrd_stat(
    example,
    text_column_name=None,
    bool_arr_suffix=None,
    bool_arr=None,
    watermark_detector=None,  # unused under the "z-score required to be run first" assumption
    args=None,
    force_error=False,
):
    if bool_arr is not None:
        bool_array = bool_arr
    else:
        bool_array_col_name = text_column_name + bool_arr_suffix
        bool_array = example[bool_array_col_name]
    if isinstance(bool_array, list):
        bool_array = np.array(bool_array)

    run_len_kwargs = dict(
        bool_arr=bool_array,
        succ_prob=1 - args.gamma,  # this applies for both variants
        variant=args.run_len_chisqrd_variant,
        bin_spec=args.run_len_chisqrd_bin_spec,
        verbose=False,  # likely never in this context
        invert_bools=False,  # legacy
        return_bin_counts=False,  # debugging only, may not work currently
        mask_zeros=args.run_len_chisqrd_mask_zeros,
        mask_leading_bins=args.run_len_chisqrd_mask_leading_bins,
        diy=False,  # legacy
        lambda_=args.run_len_chisqrd_lambda,
        return_dict=True,  # always in this context
    )

    error = True if force_error else False
    try:
        score_dict = chi_squared_runs_test(**run_len_kwargs)
    except Exception as e:
        print(e)
        error = True

    if error:
        print(f"Run length test error, got: '{bool_array}'")
        if run_len_kwargs["variant"] == "F_succ_T_runs":
            if run_len_kwargs["return_bin_counts"]:
                score_dict = F_succ_T_runs_dummy_dict_w_bins
            else:
                score_dict = F_succ_T_runs_dummy_dict_no_bins
        elif run_len_kwargs["variant"] == "T_and_F_runs":
            if run_len_kwargs["return_bin_counts"]:
                score_dict = T_and_F_runs_dummy_dict_w_bins
            else:
                score_dict = T_and_F_runs_dummy_dict_no_bins
        else:
            raise ValueError("Unknown run length test variant and return_bin_counts setting")

    # replace every key name in score dict with the text_column_name + key name
    # and then add them to the example dict
    score_dict = {text_column_name + "_run_len_chisqrd_" + k: v for k, v in score_dict.items()}
    example.update(score_dict)
    return example
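
# Note: the run-length test consumes the per-token green/red boolean mask that
# compute_z_score attaches under the "_green_token_mask" suffix, so z-score detection
# (with return_green_token_mask enabled) is assumed to have been run first. Its outputs
# follow the same prefixing convention, e.g. "w_wm_output_run_len_chisqrd_statistic".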

def compute_run_len_chsqrd_stats(
    example,
    watermark_detector=None,
    args=None,
    bool_arr_suffix="_green_token_mask",
    score_suffix="_run_len_chisqrd_statistic",
):
    # this just iterates the run_len_chisqrd function over the columns we want to compute stats for
    for col_name in RUN_LEN_CHISQRD_TEXT_COLUMN_NAMES:
        if col_name in example:
            if args.compute_scores_at_T:
                full_bool_arr = example[f"{col_name}{bool_arr_suffix}"]
                len_sequence = len(full_bool_arr)
                if len_sequence < 1:
                    force_error = True
                    full_bool_arr = [None]  # placeholder so the loop below still runs once
                    len_sequence = 1
                else:
                    force_error = False
                stats_at_T = []
                for t in range(1, len_sequence + 1):
                    bool_arr = full_bool_arr[:t]
                    example = compute_run_len_chisqrd_stat(
                        example,
                        bool_arr=bool_arr,  # this overrides the normal access of the bool_arr
                        text_column_name=col_name,
                        bool_arr_suffix=bool_arr_suffix,
                        watermark_detector=watermark_detector,
                        args=args,
                        force_error=force_error,
                    )
                    stats_at_T.append(example[f"{col_name}{score_suffix}"])
                example[f"{col_name}{score_suffix}_at_T"] = stats_at_T
            else:
                example = compute_run_len_chisqrd_stat(
                    example,
                    text_column_name=col_name,
                    bool_arr_suffix=bool_arr_suffix,
                    watermark_detector=watermark_detector,
                    args=args,
                )
    return example

def load_oracle_model(args):
    oracle_model_name = args.oracle_model_name_or_path
    print(f"Loading oracle model: {oracle_model_name}")

    if args.load_fp16:
        oracle_model = AutoModelForCausalLM.from_pretrained(
            oracle_model_name, torch_dtype=torch.float16, device_map="auto"
        )
    else:
        oracle_model = AutoModelForCausalLM.from_pretrained(oracle_model_name)

    if "llama" in oracle_model_name:
        oracle_tokenizer = LlamaTokenizer.from_pretrained(oracle_model_name)
        oracle_model.config.pad_token_id = oracle_tokenizer.pad_token_id = 0  # unk
        oracle_model.config.bos_token_id = 1
        oracle_model.config.eos_token_id = 2
    else:
        oracle_tokenizer = AutoTokenizer.from_pretrained(oracle_model_name)

    if args.use_gpu:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        if not args.load_fp16:
            oracle_model = oracle_model.to(device)
    else:
        device = "cpu"

    oracle_model.eval()
    return oracle_model, oracle_tokenizer, device

from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import CausalLMOutputWithPast


def opt_unpooled_loss(logits, labels, model):
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens
    loss_fct = CrossEntropyLoss(reduction="none")
    loss = loss_fct(shift_logits.view(-1, model.config.vocab_size), shift_labels.view(-1))
    loss = loss.reshape(shift_logits.shape[:-1])
    # compute the mean for each elm in batch where the label is not pad
    # we assume the losses are zero for pad indices
    loss = torch.sum(loss, dim=-1) / torch.sum(shift_labels != -100, dim=-1)
    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
    )


UNPOOL_FN_TABLE = {
    "opt": opt_unpooled_loss,
}


def get_unpool_fn(model_name):
    if "opt" in model_name:
        return UNPOOL_FN_TABLE["opt"]
    else:
        raise NotImplementedError(f"unpooling function not implemented for {model_name}")
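
# The "unpooled" loss above recomputes cross-entropy per sequence instead of relying on
# the single batch-averaged scalar the model returns, so each example gets its own mean
# token-level loss over non-masked (label != -100) positions. To support another model
# family, one would presumably add a matching entry, e.g. (hypothetical)
# UNPOOL_FN_TABLE["gpt2"] = gpt2_unpooled_loss, plus a branch in get_unpool_fn; the same
# shift-and-mask logic should carry over for most decoder-only HF models.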

def compute_ppl_batch(
    prefix_and_output_text=None,
    output_text=None,
    oracle_model_name=None,
    oracle_model=None,
    oracle_tokenizer=None,
    data_collator=None,
):
    inputs = []
    labels = []
    for idx in range(len(prefix_and_output_text)):
        tokd_prefix = tokenize_and_truncate(
            {"text": prefix_and_output_text[idx]},
            completion_length=0,
            hf_model_name=oracle_model_name,
            tokenizer=oracle_tokenizer,
            truncate_left=True,  # truncate from the left in case the prompt + generation exceeds the oracle's max length
            model_max_length=oracle_model.config.max_position_embeddings,
        )["input_ids"]
        # since we only want to score the "generation" part, we need the suffix tokenization length
        tokd_suffix = tokenize_and_truncate(
            {"text": output_text[idx]},
            completion_length=0,
            hf_model_name=oracle_model_name,
            tokenizer=oracle_tokenizer,
        )["input_ids"]

        tokd_labels = tokd_prefix.clone().detach()
        tokd_labels[:, : tokd_labels.shape[1] - tokd_suffix.shape[1] + 1] = -100

        inputs.append(tokd_prefix)
        labels.append(tokd_labels)

    inputs = collate_batch(input_ids=inputs, collator=data_collator).to(oracle_model.device)
    labels = collate_batch(input_ids=labels, collator=data_collator).to(oracle_model.device)
    labels[labels == oracle_tokenizer.pad_token_id] = -100  # mask out pad tokens for loss

    with torch.no_grad():
        pooled_outputs = oracle_model(input_ids=inputs, labels=labels)
        outputs = get_unpool_fn(oracle_model_name)(pooled_outputs.logits, labels, oracle_model)
        loss = (
            outputs.loss
        )  # per-example avg CE loss over all sequence positions (except where labels are -100, i.e. prompt/pad)
        # ppl = torch.tensor(math.exp(loss))
        ppl = torch.exp(loss)
    return loss.tolist(), ppl.tolist()

def evaluate_ppl(
    examples: dict,
    oracle_model_name=None,
    oracle_model=None,
    oracle_tokenizer=None,
    data_collator=None,
):
    inputs_plus_baseline_outputs = []
    baseline_outputs = []
    inputs_plus_no_wm_outputs = []
    no_wm_outputs = []
    inputs_plus_w_wm_outputs = []
    w_wm_outputs = []
    inputs_plus_w_wm_output_attackeds = []
    w_wm_output_attackeds = []

    for idx in range(len(examples["truncated_input"])):
        # pull out the required fields from the pipeline results
        inputs_plus_baseline_output = (
            f"{examples['truncated_input'][idx]}{examples['baseline_completion'][idx]}"
        )
        baseline_output = f"{examples['baseline_completion'][idx]}"

        inputs_plus_no_wm_output = (
            f"{examples['truncated_input'][idx]}{examples['no_wm_output'][idx]}"
        )
        no_wm_output = f"{examples['no_wm_output'][idx]}"

        inputs_plus_w_wm_output = (
            f"{examples['truncated_input'][idx]}{examples['w_wm_output'][idx]}"
        )
        w_wm_output = f"{examples['w_wm_output'][idx]}"

        if "w_wm_output_attacked" in examples:
            inputs_plus_w_wm_output_attacked = (
                f"{examples['truncated_input'][idx]}{examples['w_wm_output_attacked'][idx]}"
            )
            w_wm_output_attacked = f"{examples['w_wm_output_attacked'][idx]}"

        # add to lists
        inputs_plus_baseline_outputs.append(inputs_plus_baseline_output)
        baseline_outputs.append(baseline_output)
        inputs_plus_no_wm_outputs.append(inputs_plus_no_wm_output)
        no_wm_outputs.append(no_wm_output)
        inputs_plus_w_wm_outputs.append(inputs_plus_w_wm_output)
        w_wm_outputs.append(w_wm_output)
        if "w_wm_output_attacked" in examples:
            inputs_plus_w_wm_output_attackeds.append(inputs_plus_w_wm_output_attacked)
            w_wm_output_attackeds.append(w_wm_output_attacked)

    # add metrics
    loss, ppl = compute_ppl_batch(
        inputs_plus_baseline_outputs,
        baseline_outputs,
        oracle_model_name,
        oracle_model,
        oracle_tokenizer,
        data_collator=data_collator,
    )
    examples["baseline_completion_loss"] = loss
    examples["baseline_completion_ppl"] = ppl

    loss, ppl = compute_ppl_batch(
        inputs_plus_no_wm_outputs,
        no_wm_outputs,
        oracle_model_name,
        oracle_model,
        oracle_tokenizer,
        data_collator=data_collator,
    )
    examples["no_wm_output_loss"] = loss
    examples["no_wm_output_ppl"] = ppl

    loss, ppl = compute_ppl_batch(
        inputs_plus_w_wm_outputs,
        w_wm_outputs,
        oracle_model_name,
        oracle_model,
        oracle_tokenizer,
        data_collator=data_collator,
    )
    examples["w_wm_output_loss"] = loss
    examples["w_wm_output_ppl"] = ppl

    if "w_wm_output_attacked" in examples:
        loss, ppl = compute_ppl_batch(
            inputs_plus_w_wm_output_attackeds,
            w_wm_output_attackeds,
            oracle_model_name,
            oracle_model,
            oracle_tokenizer,
            data_collator=data_collator,
        )
        examples["w_wm_output_attacked_loss"] = loss
        examples["w_wm_output_attacked_ppl"] = ppl

    return examples
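
# Illustrative usage sketch (assumptions: a HF `datasets.Dataset` containing the columns
# referenced above, and a padding data collator; neither is constructed in this file):
#
#     from functools import partial
#     oracle_model, oracle_tokenizer, _ = load_oracle_model(args)
#     dataset = dataset.map(
#         partial(
#             evaluate_ppl,
#             oracle_model_name=args.oracle_model_name_or_path,
#             oracle_model=oracle_model,
#             oracle_tokenizer=oracle_tokenizer,
#             data_collator=data_collator,
#         ),
#         batched=True,
#         batch_size=args.ppl_batch_size,
#     )
#
# `evaluate_ppl` expects `examples` to be a batch (dict of lists), hence batched=True.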

def compute_repetition_diversity(example, include_repetition=False, include_diversity=False):
    for col_name in REPETITION_TEXT_COLUMN_NAMES:
        if col_name in example:
            try:
                results_tuple = measure_repetition_and_diversity(example[col_name])
            except Exception as e:
                print(
                    f"Error for '{col_name}' computing repetition and diversity on text: '{example[col_name]}'\nError:{e}"
                )
                results_tuple = dummy_rep_div_result

            if include_repetition:
                # the result dict contains the n-gram (2-, 3-, 4-gram) repetition rates as well as diversity;
                # add each key to the example, prepending the col_name
                metrics_dict = {f"{col_name}_{key}": value for key, value in results_tuple.items()}
                example.update(metrics_dict)
            if include_diversity:
                # add the diversity scores only
                example[f"{col_name}_diversity"] = results_tuple["diversity"]
                example[f"{col_name}_log_diversity"] = results_tuple["log_diversity"]
    return example
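
# As with the detection metrics, the resulting columns are prefixed by the text column
# they were computed on, e.g. "w_wm_output_diversity" and "w_wm_output_log_diversity";
# the full key set depends on what measure_repetition_and_diversity returns.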

def compute_p_sp(dataset):
    for column_pair in P_SP_TEXT_PAIR_COLUMN_NAMES:
        if column_pair[0] in dataset.features and column_pair[1] in dataset.features:
            p_sp_scores = evaluate_p_sp(dataset[column_pair[0]], dataset[column_pair[1]])
            if f"{column_pair[0]}_vs_{column_pair[1]}_p_sp" in dataset.features:
                print(
                    f"WARNING: Removing existing {column_pair[0]}_vs_{column_pair[1]}_p_sp column because it was already present"
                )
                dataset = dataset.remove_columns([f"{column_pair[0]}_vs_{column_pair[1]}_p_sp"])
            dataset = dataset.add_column(f"{column_pair[0]}_vs_{column_pair[1]}_p_sp", p_sp_scores)
    return dataset

def compute_mauve(dataset):
    """
    The current convention is to repeat the score for all rows in the dataset
    under the assumption that the final score will be retrieved via
    a groupby + take(1) operation or similar (even a `mean` would be fine)
    """
    for column_pair in MAUVE_TEXT_PAIR_COLUMN_NAMES:
        if column_pair[0] in dataset.features and column_pair[1] in dataset.features:
            mauve_score = get_mauve_score(dataset[column_pair[0]], dataset[column_pair[1]])
            if f"{column_pair[0]}_vs_{column_pair[1]}_mauve" in dataset.features:
                print(
                    f"WARNING: Removing existing {column_pair[0]}_vs_{column_pair[1]}_mauve column because it was already present"
                )
                dataset = dataset.remove_columns([f"{column_pair[0]}_vs_{column_pair[1]}_mauve"])
            dataset = dataset.add_column(
                f"{column_pair[0]}_vs_{column_pair[1]}_mauve", [mauve_score] * len(dataset)
            )
    return dataset

def compute_coherence(dataset):
    """
    Assumes the first column is the prefix or prompt to the model.
    The current convention is to repeat the score for all rows in the dataset
    under the assumption that the final score will be retrieved via
    a groupby + take(1) operation or similar (even a `mean` would be fine)
    """
    prefix_column = dataset[COHERENCE_TEXT_COLUMN_NAMES[0]]
    for generated_text_column in COHERENCE_TEXT_COLUMN_NAMES[1:]:
        if generated_text_column in dataset.features:
            coherence_score = get_coherence_score(prefix_column, dataset[generated_text_column])
            if f"{generated_text_column}_coherence" in dataset.features:
                print(
                    f"WARNING: Removing existing {generated_text_column}_coherence column because it was already present"
                )
                dataset = dataset.remove_columns([f"{generated_text_column}_coherence"])
            dataset = dataset.add_column(
                f"{generated_text_column}_coherence", [coherence_score] * len(dataset)
            )
    return dataset

def compute_detect_retrieval(dataset, args=None):
    # if we don't have the attacked column, then mock it using the retrieval db column;
    # this just means the two score cols will be the same, and we'll need to delete
    # the mock column afterwards
    was_real_attacked_ds = True
    if "w_wm_output_attacked" not in dataset.features:
        # we're faking it
        was_real_attacked_ds = False
        dataset = dataset.add_column("w_wm_output_attacked", dataset[args.retrieval_db_column])
        dataset = dataset.add_column(
            "w_wm_output_attacked_length", dataset[f"{args.retrieval_db_column}_length"]
        )

    human_detect, paraphrase_detect, generation_detect = detect_retrieval(dataset, args=args)

    if f"baseline_completion_retrieval_score" in dataset.features:
        print(
            f"WARNING: Removing existing baseline_completion_retrieval_score column because it was already present"
        )
        dataset = dataset.remove_columns(["baseline_completion_retrieval_score"])
    dataset = dataset.add_column(f"baseline_completion_retrieval_score", human_detect)

    if f"{args.retrieval_db_column}_retrieval_score" in dataset.features:
        print(
            f"WARNING: Removing existing {args.retrieval_db_column}_retrieval_score column because it was already present"
        )
        dataset = dataset.remove_columns([f"{args.retrieval_db_column}_retrieval_score"])
    dataset = dataset.add_column(f"{args.retrieval_db_column}_retrieval_score", generation_detect)

    if was_real_attacked_ds:
        if f"w_wm_output_attacked_retrieval_score" in dataset.features:
            print(
                f"WARNING: Removing existing w_wm_output_attacked_retrieval_score column because it was already present"
            )
            dataset = dataset.remove_columns(["w_wm_output_attacked_retrieval_score"])
        dataset = dataset.add_column(f"w_wm_output_attacked_retrieval_score", paraphrase_detect)
    else:
        # the attacked column was a dummy, so delete it, but first
        # sanity check that the scores are the same for the dummy column and the original
        assert all(
            [
                s1 == s2 if (not np.isnan(s1) and not np.isnan(s2)) else True
                for s1, s2 in zip(paraphrase_detect, generation_detect)
            ]
        )
        dataset = dataset.remove_columns(["w_wm_output_attacked", "w_wm_output_attacked_length"])

    return dataset

from utils.submitit import str2bool


def scheme_hparam_extractor(x):
    is_ff = "ff" in x
    is_simple_1 = ("simple_1" in x) or ("lefthash" in x)
    is_algorithm_3 = ("algorithm-3" in x) or ("selfhash" in x)
    is_anchored = "anchored" in x

    x = x.replace("ff-", "")
    x = x.replace("_prf", "")
    x = x.replace("anchored_", "")
    tup_x = x.split("-")

    # turn into a dict repr
    if is_ff:
        x_dict = {
            "prf_type": tup_x[0],
            "anchored": is_anchored,
            "context_width": int(tup_x[1]),
            "self_salt": str2bool(tup_x[2]),
        }
    elif is_simple_1:
        x_dict = {
            "prf_type": "additive",
            "anchored": False,
            "context_width": 1,
            "self_salt": False,
        }
    elif is_algorithm_3:
        x_dict = {
            "prf_type": "minhash",
            "anchored": True,
            "context_width": 4,
            "self_salt": True,
        }
    else:
        raise ValueError(f"Invalid scheme name {x} found.")

    return x_dict
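
# Illustrative example (the exact scheme strings depend on how the generation runs were
# named, so treat this as an assumption about the format rather than a spec): a freeform
# scheme name such as "ff-anchored_minhash_prf-4-True" would be parsed as
#
#     scheme_hparam_extractor("ff-anchored_minhash_prf-4-True")
#     # -> {"prf_type": "minhash", "anchored": True, "context_width": 4, "self_salt": True}
#
# while the named aliases "lefthash"/"simple_1" and "selfhash"/"algorithm-3" map to the
# fixed hyperparameter dicts hard-coded above.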