# (notebook-export residue — original lines read "Spaces:" / "No application file";
#  converted to a comment so the module parses)
# Basic imports | |
import os | |
from functools import partial | |
from argparse import Namespace | |
import numpy as np | |
# HF classses | |
from transformers import AutoTokenizer | |
from datasets import Dataset, concatenate_datasets | |
# watermarking micro lib | |
from watermark import (BlacklistLogitsProcessor, | |
compute_bl_metrics) | |
# some file i/o helpers | |
from io_utils import read_jsonlines, read_json | |
from watermark import compute_bl_metrics, BlacklistLogitsProcessor | |
########################################################################### | |
# Compute E[wl] for each example | |
########################################################################### | |
def expected_whitelist(example,
                       idx,
                       exp_wl_coef: float = None,
                       drop_spike_entropies: bool = False):
    """Attach the analytic expected/variance whitelist-token counts to one example.

    Designed for ``Dataset.map(..., with_indices=True)``: ``idx`` is accepted but
    unused. Mutates ``example`` in place and returns it.

    Args:
        example: row dict; must contain ``spike_entropies`` and
            ``w_bl_num_tokens_generated``.
        idx: row index supplied by ``datasets.map`` (ignored).
        exp_wl_coef: per-token expectation coefficient taken from the blacklist
            processor. BUGFIX: the original signature used ``exp_wl_coef: float == None``
            (a comparison used as an annotation), which left the parameter with no
            default at all; ``= None`` restores the intended keyword default.
        drop_spike_entropies: if True, delete the raw per-token entropy list after
            its mean has been recorded (saves space in the saved dataset).

    Returns:
        The same ``example`` dict, enriched with expectation/variance columns.
    """
    assert "spike_entropies" in example, "Need to construct bl processor with store_spike_ents=True to compute them in post"
    num_toks_gend = example["w_bl_num_tokens_generated"]
    avg_spike_ent = np.mean(example["spike_entropies"])
    example.update({"avg_spike_entropy": avg_spike_ent})
    if drop_spike_entropies:
        del example["spike_entropies"]
    # Expected whitelist hits: coef * T * (avg spike entropy); variance has the
    # binomial n*p*(1-p) form with per-token p = exp_wl_coef * avg_spike_ent.
    exp_num_wl = (exp_wl_coef * num_toks_gend) * avg_spike_ent
    var_num_wl = num_toks_gend * exp_wl_coef * avg_spike_ent * (1 - (exp_wl_coef * avg_spike_ent))
    example.update({"w_bl_exp_num_wl_tokens": exp_num_wl})
    example.update({"w_bl_var_num_wl_tokens": var_num_wl})
    example.update({"exp_wl_coef": exp_wl_coef})
    if num_toks_gend > 0:
        example.update({"w_bl_exp_whitelist_fraction": exp_num_wl / num_toks_gend,
                        "w_bl_var_whitelist_fraction": var_num_wl / num_toks_gend})
    else:
        # Sentinel -1 marks rows where nothing was generated (avoid div-by-zero).
        example.update({"w_bl_exp_whitelist_fraction": -1,
                        "w_bl_var_whitelist_fraction": -1})
    return example
from typing import Callable | |
def add_metadata(ex, meta_table=None):
    """Merge run-level metadata fields into one example row (in place).

    Args:
        ex: example dict to enrich (mutated and returned, per ``Dataset.map``).
        meta_table: dict of run metadata to copy onto the row. ROBUSTNESS: the
            original called ``ex.update(None)`` (a TypeError) when the default
            was used; ``None`` is now treated as "no metadata to add".

    Returns:
        The same ``ex`` dict.
    """
    if meta_table is not None:
        ex.update(meta_table)
    return ex
def str_replace_bug_check(example, idx):
    """Filter predicate: keep only rows untouched by the prompt-strip replace.

    Re-runs the ``str.replace`` that strips the (last-char-truncated) input
    prompt out of ``baseline_completion``. If the replace changes anything,
    the original generation was affected by the replacement bug, so the row
    is flagged for dropping.

    Returns True to keep the row, False to drop it. Note the row dict is
    mutated: ``baseline_completion`` is overwritten with the stripped text.
    """
    original = example["baseline_completion"]
    stripped = original.replace(example["truncated_input"][:-1], "")
    example["baseline_completion"] = stripped
    if stripped == original:
        return True
    print("baseline input replacement bug occurred, skipping row!")
    return False
def load_all_datasets(run_names: list[str] = None,
                      base_run_dir: str = None,
                      meta_name: str = None,
                      gen_name: str = None,
                      apply_metric_func: bool = False,
                      convert_to_pandas: bool = False,
                      drop_buggy_rows: bool = False,
                      limit_output_tokens: int = 0,
                      save_ds: bool = True,
                      save_dir: str = None):
    """Load, enrich, and concatenate the per-run generation tables.

    For each run directory under ``base_run_dir``: read the run metadata json
    (``meta_name``) and the jsonl generation table (``gen_name``, either a fixed
    filename or a callable mapping run name -> filename), optionally filter
    buggy rows, broadcast the run metadata onto every row, optionally recompute
    the blacklist/watermark metrics and expected-whitelist statistics, then
    concatenate everything into one ``datasets.Dataset``.

    Args:
        run_names: run directory names to load.
        base_run_dir: parent directory containing one subdir per run.
        meta_name: filename of the per-run metadata json.
        gen_name: filename (or callable of run name) of the per-run jsonl table.
        apply_metric_func: if True, re-run ``compute_bl_metrics`` per row
            (expensive; needs each run's tokenizer).
        convert_to_pandas: return a pandas DataFrame instead of a Dataset.
        drop_buggy_rows: filter rows flagged by ``str_replace_bug_check``.
        limit_output_tokens: forwarded to ``compute_bl_metrics``.
        save_ds: if True, persist the concatenated dataset to ``save_dir``.
        save_dir: target directory for ``save_to_disk``.

    Returns:
        A concatenated ``datasets.Dataset``, or a ``pandas.DataFrame`` when
        ``convert_to_pandas`` is True.
    """
    print(f"Loading {len(run_names)} datasets from {base_run_dir}...")
    if not isinstance(gen_name, Callable):
        # NOTE(review): only the gen files are existence-checked here, although
        # the assert message also mentions meta_name — confirm whether the meta
        # files should be checked too.
        file_check = lambda name: os.path.exists(f"{base_run_dir}/{name}/{gen_name}")
        assert all([file_check(name) for name in run_names]), f"Make sure all the run dirs contain the required data files: {meta_name} and {gen_name}"
    all_datasets = []
    for i, run_name in enumerate(run_names):
        print(f"[{i}] Loading dataset")
        run_base_dir = f"{base_run_dir}/{run_name}"
        gen_table_meta_path = f"{run_base_dir}/{meta_name}"
        # gen_name may be a callable so attack runs with per-run filenames work.
        if isinstance(gen_name, Callable):
            gen_table_path = f"{run_base_dir}/{gen_name(run_name)}"
        else:
            gen_table_path = f"{run_base_dir}/{gen_name}"
        # load the raw files
        gen_table_meta = read_json(gen_table_meta_path)
        gen_table_lst = [ex for ex in read_jsonlines(gen_table_path)]
        gen_table_ds = Dataset.from_list(gen_table_lst)
        print(f"Original dataset length={len(gen_table_ds)}")
        # drop the rows where the string replace thing happens
        if drop_buggy_rows:
            gen_table_ds_filtered = gen_table_ds.filter(str_replace_bug_check, batched=False, with_indices=True)
        else:
            gen_table_ds_filtered = gen_table_ds
        # enrich all rows with the run metadata
        add_meta = partial(
            add_metadata,
            meta_table=gen_table_meta
        )
        gen_table_w_meta = gen_table_ds_filtered.map(add_meta, batched=False)
        # optionally, apply the metric function(s) - somewhat expensive
        # want to do this here rather than at end because you need each run's tokenizer
        # though tbh it would be odd if they're not the same, but you can check that at the end
        if apply_metric_func:
            tokenizer = AutoTokenizer.from_pretrained(gen_table_meta["model_name"])
            comp_bl_metrics = partial(
                compute_bl_metrics,
                tokenizer=tokenizer,
                hf_model_name=gen_table_meta["model_name"],
                initial_seed=gen_table_meta["initial_seed"],
                dynamic_seed=gen_table_meta["dynamic_seed"],
                bl_proportion=gen_table_meta["bl_proportion"],
                use_cuda=True,  # this is obvi critical to match the pseudorandomness
                record_hits=True,
                limit_output_tokens=limit_output_tokens,
            )
            gen_table_w_bl_metrics = gen_table_w_meta.map(comp_bl_metrics, batched=False, with_indices=True)
            # Construct the blacklist processor so you can get the expectation coef
            all_token_ids = list(tokenizer.get_vocab().values())
            vocab_size = len(all_token_ids)
            # Rehydrate the run's CLI args from the saved metadata dict.
            args = Namespace()
            args.__dict__.update(gen_table_meta)
            bl_processor = BlacklistLogitsProcessor(bad_words_ids=None,
                                                   store_bl_ids=False,
                                                   store_spike_ents=True,
                                                   eos_token_id=tokenizer.eos_token_id,
                                                   vocab=all_token_ids,
                                                   vocab_size=vocab_size,
                                                   bl_proportion=args.bl_proportion,
                                                   bl_logit_bias=args.bl_logit_bias,
                                                   bl_type=args.bl_type,
                                                   initial_seed=args.initial_seed,
                                                   dynamic_seed=args.dynamic_seed)
            # Expected-whitelist stats require stored spike entropies; older runs
            # without that column just skip this enrichment step.
            if "spike_entropies" in gen_table_w_bl_metrics.column_names:
                comp_exp_num_wl = partial(
                    expected_whitelist,
                    exp_wl_coef=bl_processor.expected_wl_coef,
                    drop_spike_entropies=False,
                    # drop_spike_entropies=True,
                )
                gen_table_w_spike_ents = gen_table_w_bl_metrics.map(comp_exp_num_wl, batched=False, with_indices=True)
                final_single_run_ds = gen_table_w_spike_ents
            else:
                final_single_run_ds = gen_table_w_bl_metrics
        else:
            final_single_run_ds = gen_table_w_meta
        all_datasets.append(final_single_run_ds)
    # Stack every run's rows into one table (schemas must be compatible).
    ds = concatenate_datasets(all_datasets)
    if save_ds:
        ds.save_to_disk(save_dir)
    if convert_to_pandas:
        df = ds.to_pandas()
        return df
    else:
        return ds
# ---------------------------------------------------------------------------
# Driver script: pick a sweep output directory, select runs, load + save.
# The commented-out lines are alternate sweep/attack configurations kept for
# quick switching between experiments.
# ---------------------------------------------------------------------------
output_dir = "/cmlscratch/jkirchen/spiking-root/lm-blacklisting/output_large_sweep"
# output_dir = "/cmlscratch/jkirchen/spiking-root/lm-blacklisting/output_large_sweep_downsize"
# output_dir = "/cmlscratch/jkirchen/spiking-root/lm-blacklisting/output_large_sweep_downsize"
# output_dir = "/cmlscratch/jkirchen/spiking-root/lm-blacklisting/output_greedy_redo"
# output_dir = "/cmlscratch/jkirchen/spiking-root/lm-blacklisting/output_greedy_gamma_0-25"

# Keep only runs whose metrics table exists, then narrow to one dataset flavor.
run_names = list(filter(lambda name: os.path.exists(f"{output_dir}/{name}/gen_table_w_metrics.jsonl"), sorted(os.listdir(output_dir))))
run_names = list(filter(lambda name: "realnewslike" in name, run_names))
# run_names = list(filter(lambda name: "pile" in name, run_names))
# run_names = list(filter(lambda name: "c4_en" in name, run_names))

# output_dir = "/cmlscratch/jkirchen/spiking-root/lm-blacklisting/output_attacked_greedy_updated"
# # output_dir = "/cmlscratch/jkirchen/spiking-root/lm-blacklisting/output_attacked_new"
# run_names = list(filter(lambda name: os.path.exists(f"{output_dir}/{name}/gen_table_w{('_'+name) if 't5' in name else ''}_attack_metrics.jsonl"), sorted(os.listdir(output_dir))))
# run_names = list(filter(lambda name: os.path.exists(f"{output_dir}/{name}/gen_table_w_attack_metrics.jsonl"), sorted(os.listdir(output_dir))))

runs_to_load = run_names
print(len(run_names))
for name in run_names: print(name)

# Sanity check: report which selected runs actually have their metrics file.
runs_ready = [os.path.exists(f"{output_dir}/{name}/gen_table_w_metrics.jsonl") for name in runs_to_load]
# runs_ready = [os.path.exists(f"{output_dir}/{name}/gen_table_w_attack_metrics.jsonl") for name in runs_to_load]
print(f"all runs ready? {all(runs_ready)}\n{runs_ready}")

# Output dataset name; earlier experiment names kept for reference.
# save_name = "analysis_ds_1-21_greedy_redo"
# save_name = "analysis_ds_1-21_greedy_redo_truncated"
# save_name = "analysis_ds_1-21_greedy_redo_truncated_sanity_check"
# save_name = "analysis_ds_1-19_realnews_1-3_v2_hitlist_check"
# save_name = "analysis_ds_1-20_more_attack"
# save_name = "analysis_ds_1-23_greedy_gamma_0-25_truncated"
# save_name = "analysis_ds_1-21_greedy_attacked_updated_truncated"
# save_name = "analysis_ds_1-23_pile_1-3"
# save_name = "analysis_ds_1-23_en_1-3"
save_name = "analysis_ds_1-30_realnews_2-7"
save_dir = f"input/{save_name}"

# Load every selected run, recompute metrics, and persist the merged dataset.
raw_data = load_all_datasets(run_names=runs_to_load,
                             base_run_dir=output_dir,
                             meta_name="gen_table_meta.json",
                             gen_name="gen_table_w_metrics.jsonl",
                             # gen_name="gen_table_w_attack_metrics.jsonl",
                             apply_metric_func=True,
                             # drop_buggy_rows=True,
                             drop_buggy_rows=False,
                             # limit_output_tokens=200,
                             convert_to_pandas=False,
                             save_ds=True,
                             save_dir=save_dir)
print(f"All finished with {save_dir}!!")