# coding=utf-8
# Copyright 2023 Authors of "A Watermark for Large Language Models"
# available at https://arxiv.org/abs/2301.10226
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import argparse
from functools import partial

from tqdm import tqdm
import wandb

print(f"Current huggingface cache dir: {os.environ['HF_HOME']}") | |
# HF classes
from transformers import LogitsProcessorList, DataCollatorWithPadding

# better bool flag type for argparse
from utils.submitit import str2bool

# some file i/o helpers
from utils.io import write_jsonlines, write_json

# watermarking functionality
from watermark_processor import WatermarkLogitsProcessor

# generation pipeline helpers
from utils.generation import (
    MAX_GENERATIONS,
    load_model,
    load_hf_dataset,
    check_input_lengths,
    check_output_lengths,
    tokenize_for_generation,
    generate,
)
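
# NOTE: utils.* and watermark_processor above are local modules from this repository,
# not pip-installable packages.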


def main(args):
    ###########################################################################
    # Start logging
    ###########################################################################
    # storing slurm info to allow auditing logfiles later
    args.SLURM_JOB_ID = os.getenv("SLURM_JOB_ID")
    args.SLURM_ARRAY_JOB_ID = os.getenv("SLURM_ARRAY_JOB_ID")
    args.SLURM_ARRAY_TASK_ID = os.getenv("SLURM_ARRAY_TASK_ID")

    if args.wandb:
        # start a new wandb run to track this experiment, will send data to it later
        run = wandb.init(
            # set the wandb project where this run will be logged
            project=args.wandb_project,
            entity=args.wandb_entity,
            name=f"{args.run_name}",
            # track hyperparameters and run metadata
            config=args,
            tags=args.wandb_tags,
        )

    ###########################################################################
    # Create the output dir
    ###########################################################################
    print(f"Output dir for this run: {args.output_dir}")
    # notify if exists
    if os.path.exists(args.output_dir):
        print("Output dir for this run already exists!")
        print(f"Contents: {sorted(os.listdir(args.output_dir))}")
    else:
        # create the output dir where run artifacts are stored
        os.makedirs(args.output_dir)

    ###########################################################################
    # Load the dataset
    ###########################################################################
    # basic ops like shuffling and select are done in load fn
    dataset = load_hf_dataset(args)

    ###########################################################################
    # Instantiate model and tokenizer
    ###########################################################################
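    # load_model (a utils.generation helper, not shown here) is expected to honor the
    # --load_fp16 / --use_gpu flags and return the model, its tokenizer, and the torch
    # device that generation should run on.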
    model, tokenizer, device = load_model(args)

    ###########################################################################
    # Configure the prompt construction partial
    ###########################################################################
    # Construct the data filtering/sampling scheme partials
    token_kwargs = dict(
        hf_model_name=args.model_name_or_path,
        tokenizer=tokenizer,
        args=args,
    )
    if args.input_truncation_strategy == "prompt_length":
        token_kwargs.update(dict(min_prompt_tokens=args.min_prompt_tokens))
    elif args.input_truncation_strategy == "completion_length":
        token_kwargs.update(dict(max_new_tokens=args.max_new_tokens))
    elif args.input_truncation_strategy == "no_truncation":
        # truncate_input_for_prompt is a bool flag that is set by
        # the dataset loading function, semi-redundant, to make sure
        # people are very aware of which input data style they are using
        assert (
            args.truncate_input_for_prompt == False
        ), "Cannot truncate input for prompt if 'no_truncation' strategy is specified"
    else:
        raise ValueError(f"Unknown input truncation strategy {args.input_truncation_strategy}")

    tokenize_prompts = partial(tokenize_for_generation, **token_kwargs)
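    # tokenize_for_generation (a utils.generation helper, not shown here) is applied per-row via
    # dataset.map below; per the strategy configured above it tokenizes each raw sample and splits
    # it into a prompt prefix plus the remaining "baseline completion" reference text.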

    ###########################################################################
    # Configure the I/O data validation partials
    ###########################################################################
    input_check_kwargs = dict(
        min_sample_len=args.min_sample_tokens,
        max_input_len=model.config.max_position_embeddings,
        max_new_tokens=args.max_new_tokens,
    )
    if args.input_filtering_strategy == "prompt_length":
        input_check_kwargs.update(dict(min_prompt_len=args.min_prompt_tokens, min_completion_len=0))
    elif args.input_filtering_strategy == "completion_length":
        input_check_kwargs.update(dict(min_prompt_len=0, min_completion_len=args.max_new_tokens))
    elif args.input_filtering_strategy == "prompt_and_completion_length":
        input_check_kwargs.update(
            dict(min_prompt_len=args.min_prompt_tokens, min_completion_len=args.max_new_tokens)
        )
    elif args.input_filtering_strategy == "no_filter":
        input_check_kwargs.update(dict(min_prompt_len=0, min_completion_len=0))
    else:
        raise ValueError(f"Unknown input filtering strategy {args.input_filtering_strategy}")
    input_check = partial(check_input_lengths, **input_check_kwargs)

    if args.output_filtering_strategy == "max_new_tokens":
        output_kwargs = dict(min_output_len=args.max_new_tokens)
    elif args.output_filtering_strategy == "no_filter":
        output_kwargs = dict(min_output_len=0)
    else:
        raise ValueError(f"Unknown output filtering strategy {args.output_filtering_strategy}")
    output_check = partial(check_output_lengths, **output_kwargs)
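    # input_check is applied as a dataset filter *before* generation (prompt/completion length
    # requirements), while output_check is applied per-row in the main loop below and only counts
    # a generation as "satisfactory" when the model produced enough new tokens.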

    ###########################################################################
    # Construct the watermark processor
    ###########################################################################
    watermark_processor = WatermarkLogitsProcessor(
        vocab=list(tokenizer.get_vocab().values()),
        gamma=args.gamma,
        delta=args.delta,
        seeding_scheme=args.seeding_scheme,
        store_spike_ents=args.store_spike_ents,
        select_green_tokens=True,
    )
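    # gamma is the fraction of the vocabulary placed on the green list at each step, delta is the
    # (absolute) logit bias added to those green tokens, and seeding_scheme controls how the green
    # list is re-seeded from context (the default "simple_1" seeds from the preceding token).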

    ###########################################################################
    # Configure the generation partials
    ###########################################################################
    gen_kwargs = dict(max_new_tokens=args.max_new_tokens)

    if args.use_sampling:
        gen_kwargs.update(
            dict(
                do_sample=True,
                top_k=args.top_k,
                top_p=args.top_p,
                typical_p=args.typical_p,
                temperature=args.sampling_temp,
            )
        )
    else:
        gen_kwargs.update(dict(num_beams=args.num_beams))

    generate_without_watermark = partial(model.generate, **gen_kwargs)
    generate_with_watermark = partial(
        model.generate, logits_processor=LogitsProcessorList([watermark_processor]), **gen_kwargs
    )
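    # Both partials wrap the same model.generate call; the watermarked variant simply plugs the
    # WatermarkLogitsProcessor into HF generate's `logits_processor` hook so that green-list
    # tokens receive the +delta logit bias at every decoding step.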

    # construct the collator
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=True, pad_to_multiple_of=8)
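    # pad_to_multiple_of=8 rounds padded sequence lengths up to a multiple of 8, which keeps
    # fp16 matmuls on tensor-core-friendly shapes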

    generation_partial = partial(
        generate,
        data_collator=data_collator,
        generate_without_watermark=generate_without_watermark,
        generate_with_watermark=generate_with_watermark,
        watermark_processor=watermark_processor,
        tokenizer=tokenizer,
        device=device,
        args=args,
    )

    ###########################################################################
    # Compose the partials to create the pipeline
    ###########################################################################
    # tokenize and truncate the row inputs to create prompts according to the strategy spec'd above
    dataset_w_prompts = dataset.map(tokenize_prompts, batched=False)

    # filter the rows of the dataset based on length checks for the tokenized prompts and baseline completions
    dataset_input_len_filtered = dataset_w_prompts.filter(input_check, batched=False)

    # need to remove the input tensor column after this map
    # bc it persists between the prompt creation and generation maps
    columns_to_remove = args.columns_to_remove + ["input_ids"]

    # call the generation partial on each prompt in the dataset
    dataset_w_generations = dataset_input_len_filtered.map(
        generation_partial,
        batched=True,
        batch_size=args.generation_batch_size,
        remove_columns=columns_to_remove,
    )
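    # when --stream_dataset is set, the dataset is an IterableDataset, so the map/filter calls above
    # are lazy: tokenization, filtering, and generation actually execute as the loop below pulls rows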

    ###########################################################################
    # Main loop - actually executes the generation pipeline.
    # and accumulates the result rows in a list, assumes list is "small"-ish
    # and we aren't accumulating any tensors or other memory hogging artifacts
    ###########################################################################
    processed_examples = []
    ds_iterator = iter(dataset_w_generations)
    i = 0
    total_steps = 0
    pbar = tqdm(total=args.min_generations)
    while i < args.min_generations:
        try:
            ex = next(ds_iterator)
            total_steps += 1
        except StopIteration:
            break

        if args.verbose:
            # log basics to stdout
            print("#" * 80)
            print(f"dataset index: {ex['idx']}")
            print(f"orig_sample_length: {ex['orig_sample_length']}")
            print(f"prompt_length: {ex['prompt_length']}")
            print(f"real_completion_length: {ex['baseline_completion_length']}")
            print(f"no_wm_output_length: {ex['no_wm_output_length']}")
            print(f"w_wm_output_length: {ex['w_wm_output_length']}")

            print("\ntruncated_input: ")
            print(ex["truncated_input"])
            print("\nbaseline_completion: ")
            print(ex["baseline_completion"])
            print("\nno_wm_output: ")
            print(ex["no_wm_output"])
            print("\nw_wm_output: ")
            print(ex["w_wm_output"])

        processed_examples.append(ex)

        if output_check(ex):
            i += 1
            pbar.update(1)
        else:
            print(
                f"\n{i} of {len(processed_examples)} rows were satisfactory so far, {round(i/args.min_generations, 2)} of total.",
                f"\nCurrent generation overhead ratio: {round(len(processed_examples)/(i+1), 3)}.",
            )

        # if using wandb, log progress to wandb
        if args.wandb:
            run.log(
                {
                    "num_satisfactory_samples": i,
                    "progress_ratio": i / args.min_generations,
                    "generation_overhead_ratio": len(processed_examples) / (i + 1),
                    "total_generated_samples": len(processed_examples),
                },
                step=total_steps,
            )

    pbar.close()
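
    # summarize how many raw rows had to be generated per satisfactory (length-checked) sample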
    print(
        "#" * 80,
        f"\nGeneration output length check overhead was num rows processed={len(processed_examples)}",
        f"for {args.min_generations} samples. Ratio: {round(len(processed_examples)/args.min_generations, 3)}",
    )

    if i < args.min_generations:
        print(
            "#" * 80,
            f"\nWarning, may have run out of data before {args.min_generations} satisfactory samples were generated. ",
            f"\nNote, raw dataset limit was {args.limit_indices} rows.",
            f"\n{len(processed_examples)} prompts passed input checks and yielded generations, and {i} passed output checks.",
            f"\nProgress made: {round(i/args.min_generations, 2)}",
        )

    ###########################################################################
    # Generation jsonl dumping
    ###########################################################################
    gen_table_meta_path = f"{args.output_dir}/gen_table_meta.json"
    gen_table_path = f"{args.output_dir}/gen_table.jsonl"
    safe_gen_table_path = f"{args.output_dir}/gen_table_safe.jsonl"

    args.gen_table_already_existed = False

    if os.path.exists(gen_table_path):
        args.gen_table_already_existed = True
        print(f"Found existing generation files at this output dir: {args.output_dir}")
        if args.overwrite:
            print("Overwriting old generation files.")
        else:
            print(
                "Writing generations at alternate, safe path and exiting. Note! this only works once. "
                "Safe version will get overwritten next time ... "
            )
            gen_table_path = safe_gen_table_path

    gen_table_meta = args.__dict__
    gen_table = processed_examples

    write_jsonlines(gen_table, gen_table_path)
    write_json(gen_table_meta, gen_table_meta_path, indent=4)

    # finish the wandb run
    if args.wandb:
        run.finish()

    return  # reload in separate script for metric measurement


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run watermarked huggingface LM generation pipeline"
    )
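    # Example invocation (illustrative only; substitute the actual script filename and your paths):
    #   python <this_script>.py --model_name_or_path=facebook/opt-1.3b \
    #       --dataset_name=c4 --dataset_config_name=realnewslike \
    #       --max_new_tokens=100 --min_generations=500 --output_dir=./output/example_run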
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="facebook/opt-1.3b",
        help="Main model, path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--load_fp16",
        type=str2bool,
        default=True,
        help="Whether to run model in float16 precision.",
    )
    parser.add_argument(
        "--use_gpu",
        type=str2bool,
        default=True,
        help="Whether to run inference and watermark hashing/seeding/permutation on gpu.",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="c4",
        help="The name of the dataset to use (via the datasets library).",
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default="realnewslike",
        help="The configuration name of the dataset to use (via the datasets library).",
    )
    parser.add_argument(
        "--dataset_split",
        type=str,
        default="train",
        help="The split of the dataset to use (via the datasets library).",
    )
    parser.add_argument(
        "--stream_dataset",
        type=str2bool,
        default=True,
        help="Whether to stream the dataset from the web or download it locally.",
    )
    parser.add_argument(
        "--columns_to_remove",
        type=str,
        default=None,
        help="Comma separated list of columns to remove from the dataset before generation.",
    )
    parser.add_argument(
        "--shuffle_dataset",
        type=str2bool,
        default=False,
        help="Whether to shuffle the dataset before sampling.",
    )
    parser.add_argument(
        "--shuffle_seed",
        type=int,
        default=1234,
        help="The seed to use for dataset shuffle op.",
    )
    parser.add_argument(
        "--shuffle_buffer_size",
        type=int,
        default=10_000,
        help="The buffer size to use for dataset shuffle op - takes n rows first, then shuffles those indices",
    )
    parser.add_argument(
        "--prompt_id",
        type=int,
        default=0,
        help="If the dataset supports multiple instruction prompts, denotes which one to use. 0 is default/no prompt.",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=100,
        help="The number of tokens to generate using the model, and the num tokens removed from real text sample",
    )
    parser.add_argument(
        "--min_prompt_tokens",
        type=int,
        default=50,  # 500
        help="The minimum number of tokens to keep as the prompt prefix when truncating raw samples.",
    )
    parser.add_argument(
        "--min_sample_tokens",
        type=int,
        default=0,
        help="The minimum length of raw prompt samples to consider.",
    )
    parser.add_argument(
        "--limit_indices",
        type=int,
        default=None,
        help="The number of examples (first N) to pull from the dataset, if None, pull all, and then set this arg to the number of rows in the dataset.",
    )
    parser.add_argument(
        "--min_generations",
        type=int,
        default=500,
        help="The minimum number of valid generations according to the output check strat to sample.",
    )
    parser.add_argument(
        "--input_truncation_strategy",
        type=str,
        default="completion_length",
        choices=["no_truncation", "completion_length", "prompt_length"],
        help="The strategy to use when tokenizing and truncating raw inputs to make prompts.",
    )
    parser.add_argument(
        "--input_filtering_strategy",
        type=str,
        default="completion_length",
        choices=["no_filter", "completion_length", "prompt_length", "prompt_and_completion_length"],
        help="The strategy to use when filtering rows based on the lengths of their prompts and completions.",
    )
    parser.add_argument(
        "--output_filtering_strategy",
        type=str,
        default="no_filter",
        choices=["no_filter", "max_new_tokens"],
        help=(
            "The strategy to use when filtering/skipping rows if the model didn't "
            "generate enough tokens to facilitate analysis."
        ),
    )
    parser.add_argument(
        "--use_sampling",
        type=str2bool,
        default=False,
        help="Whether to perform sampling during generation. (non-greedy decoding)",
    )
    parser.add_argument(
        "--sampling_temp",
        type=float,
        default=0.7,
        help="The temperature to use when generating using multinom sampling",
    )
    parser.add_argument(
        "--top_k",
        type=int,
        default=0,
        help="The top k to use when generating using top_k version of multinom sampling",
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=1.0,
        help="The top p to use when generating using top_p version of sampling",
    )
    parser.add_argument(
        "--typical_p",
        type=float,
        default=1.0,
        help="The typical p to use when generating using typical decoding version of multinom sampling",
    )
    parser.add_argument(
        "--num_beams",
        type=int,
        default=1,
        help="The number of beams to use where '1' is no beam search.",
    )
    parser.add_argument(
        "--generation_seed",
        type=int,
        default=None,
        help="Seed for setting the torch rng prior to generation using any decoding scheme with randomness.",
    )
    parser.add_argument(
        "--generation_batch_size",
        type=int,
        default=4,
        help="The batch size to use for generation.",
    )
    parser.add_argument(
        "--seeding_scheme",
        type=str,
        default="simple_1",
        help="The seeding procedure to use for the watermark.",
    )
    parser.add_argument(
        "--gamma",
        type=float,
        default=0.25,
        help="The ratio of tokens to put in the greenlist when splitting the vocabulary",
    )
    parser.add_argument(
        "--delta",
        type=float,
        default=2.0,
        help="The amount of bias (absolute) to add to the logits of the greenlist tokens at every generation step",
    )
    parser.add_argument(
        "--store_spike_ents",
        type=str2bool,
        default=True,
        help="Whether to store the spike entropies while generating with watermark processor.",
    )
    parser.add_argument(
        "--verbose",
        type=str2bool,
        default=False,
        help="Whether to log the generations to stdout.",
    )
    parser.add_argument(
        "--wandb",
        type=str2bool,
        default=False,
        help="Whether to log to wandb.",
    )
    parser.add_argument(
        "--wandb_project",
        type=str,
        default="lm-watermarking",
        help="The name of the wandb project.",
    )
    parser.add_argument(
        "--wandb_entity",
        type=str,
        default="jwkirchenbauer",
        help="The wandb entity/user for the project.",
    )
    parser.add_argument(
        "--wandb_tags",
        type=str,
        default="",
        help="The comma separated list of tags to add to the wandb run.",
    )
    parser.add_argument(
        "--run_name",
        type=str,
        default=None,
        help="The unique name for the run.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./output",
        help="The output directory where run artifacts (generations and metadata) are written.",
    )
    parser.add_argument(
        "--overwrite",
        type=str2bool,
        default=False,
        help="Allow overwriting of old generation files at the same output location.",
    )
    args = parser.parse_args()

    ###########################################################################
    # Argument validation and conditional setting
    ###########################################################################
    # for removing some columns to save space
    args.columns_to_remove = args.columns_to_remove.split(",") if args.columns_to_remove else []

    # if decoding scheme is not sampling, then set generation seed to None
    # to avoid confusion and calling the torch rng unnecessarily
    args.generation_seed = args.generation_seed if args.use_sampling else None

    # a value <= 0 for min_generations means no specified minimum,
    # with the assumption that the hardcoded MAX_GENERATIONS cap will end the loop instead
    if args.min_generations <= 0:
        args.min_generations = MAX_GENERATIONS
        print(
            f"Warning: min_generations was <= 0. A hardcoded value of {MAX_GENERATIONS} will be used to limit the generation loop."
        )

    if args.limit_indices is None:
        print("No limit_indices specified, pulling all examples from the dataset.")
    else:
        print(f"Limiting iteration to {args.limit_indices} examples from the dataset.")

    # split wandb tags
    if args.wandb_tags != "":
        args.wandb_tags = args.wandb_tags.split(",")
    else:
        args.wandb_tags = []

    main(args)