# coding=utf-8 # Copyright 2023 Authors of "A Watermark for Large Language Models" # available at https://arxiv.org/abs/2301.10226 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch # HF classes from datasets import load_dataset, IterableDataset from torch import Tensor from tokenizers import Tokenizer from transformers import ( AutoTokenizer, LlamaTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, DataCollatorWithPadding, ) from .data.lfqa import load_lfqa from .data.essays import load_essays from .data.wikitext import load_wikitext MAX_GENERATIONS = int(10000) # Hardcoded max length to avoid infinite loop def load_model(args): """Load and return the model and tokenizer""" args.is_seq2seq_model = any( [(model_type in args.model_name_or_path) for model_type in ["t5", "T0"]] ) args.is_decoder_only_model = any( [(model_type in args.model_name_or_path) for model_type in ["gpt", "opt", "bloom", "llama"]] ) if args.is_seq2seq_model: model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name_or_path) elif args.is_decoder_only_model: if args.load_fp16: model = AutoModelForCausalLM.from_pretrained( args.model_name_or_path, torch_dtype=torch.float16, device_map="auto" ) else: model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path) else: raise ValueError(f"Unknown model type: {args.model_name_or_path}") if args.use_gpu: device = "cuda" if torch.cuda.is_available() else "cpu" if args.load_fp16: pass else: model = model.to(device) else: device = "cpu" model.eval() if args.is_decoder_only_model: padding_side = "left" else: raise NotImplementedError( "Need to check how to handle padding for seq2seq models when calling generate" ) if "llama" in args.model_name_or_path: tokenizer = LlamaTokenizer.from_pretrained( args.model_name_or_path, padding_side=padding_side ) model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk model.config.bos_token_id = 1 model.config.eos_token_id = 2 else: tokenizer = AutoTokenizer.from_pretrained( args.model_name_or_path, padding_side=padding_side ) args.model_max_length = model.config.max_position_embeddings return model, tokenizer, device def add_idx(example, idx): example.update({"idx": idx}) return example def load_hf_dataset(args): dataset_name, dataset_config_name = args.dataset_name, args.dataset_config_name if dataset_name == "lfqa": dataset = load_lfqa(args) args.__dict__.update( { "truncate_input_for_prompt": False, "input_col_name": "prefix", "ref_output_col_name": "gold_completion", } ) # other args set within the load_lfqa function elif dataset_name == "wikitext": dataset = load_wikitext(args) args.__dict__.update( { "truncate_input_for_prompt": True, "input_col_name": "text", "ref_output_col_name": None, } ) # other args set within the load_wikitext function elif dataset_name == "essays": dataset = load_essays(args) args.__dict__.update( { "truncate_input_for_prompt": False, "input_col_name": "instructions", "ref_output_col_name": "essays", } ) elif dataset_name == "cml_pile": subsets = [dataset_config_name] dataset = load_dataset( "./data/cml_pile.py", subsets=subsets, streaming=args.stream_dataset, split=None, ignore_verifications=True, )[args.dataset_split] args.__dict__.update( { "truncate_input_for_prompt": True, "input_col_name": "text", "ref_output_col_name": None, } ) else: dataset = load_dataset( dataset_name, dataset_config_name, split=args.dataset_split, streaming=args.stream_dataset, ) if "c4" in dataset_name: args.__dict__.update( { "truncate_input_for_prompt": True, "input_col_name": "text", "ref_output_col_name": None, } ) args.columns_to_remove = list( set(args.columns_to_remove + ["text", "timestamp", "url"]) ) elif "pile" in dataset_name: args.__dict__.update( { "truncate_input_for_prompt": True, "input_col_name": "text", "ref_output_col_name": None, } ) args.columns_to_remove = list(set(args.columns_to_remove + ["text", "meta"])) else: raise NotImplementedError( f"Dataset {dataset_name} not yet supported. Please add specs to load_hf_dataset function." ) # add index to each row of dataset indexed_dataset = dataset.map(add_idx, batched=False, with_indices=True) # shuffle the first shuffle_buffer_size rows of streaming dataset, or whole dataset if not streaming # and take/select only the first n rows of the dataset (which caps the total number of pipeline iters possible) if isinstance(indexed_dataset, IterableDataset): shuffled_dataset = ( indexed_dataset.shuffle(seed=args.shuffle_seed, buffer_size=args.shuffle_buffer_size) if args.shuffle_dataset else indexed_dataset ) limited_dataset = ( shuffled_dataset.take(args.limit_indices) if args.limit_indices is not None else shuffled_dataset ) else: shuffled_dataset = ( indexed_dataset.shuffle(seed=args.shuffle_seed) if args.shuffle_dataset else indexed_dataset ) limited_dataset = ( shuffled_dataset.select(range(args.limit_indices)) if args.limit_indices is not None else shuffled_dataset ) if args.limit_indices is None: try: args.limit_indices = len(limited_dataset) except Exception as e: # can't infer length of dataset, probably because it's an IterableDataset pass return limited_dataset def check_input_lengths( example, min_sample_len=0, min_prompt_len=0, min_completion_len=0, max_input_len=None, max_new_tokens=None, ): orig_sample_length = example["orig_sample_length"] prompt_length = example["prompt_length"] real_completion_length = example["baseline_completion_length"] if max_input_len is not None: assert ( max_new_tokens is not None ), "need to specify max_new_tokens if max_input_length is specified" conds = all( [ orig_sample_length >= min_sample_len, prompt_length >= min_prompt_len, real_completion_length >= min_completion_len, ( ((prompt_length + max_new_tokens) <= max_input_len) if max_input_len is not None else True ), ] ) return conds def check_output_lengths(example, min_output_len=0): # FIXME, maybe should check baseline completion length too no_wm_output_len = example["no_wm_output_length"] w_wm_output_len = example["w_wm_output_length"] conds = all( [ no_wm_output_len >= min_output_len, w_wm_output_len >= min_output_len, ] ) return conds def tokenize_and_truncate( example: dict, input_col_name: str = "text", completion_length: int = None, prompt_length: int = None, hf_model_name: str = None, tokenizer=None, truncate_left=False, model_max_length=None, ): """take hf dataset entry and preprocess it for completion by a model""" assert hf_model_name is not None, "need model name to know whether to adjust wrt special tokens" assert input_col_name in example, f"expects {input_col_name} field to be present" # tokenize inputs_ids = tokenizer(example[input_col_name], return_tensors="pt")["input_ids"] example.update({"untruncated_inputs": inputs_ids}) if truncate_left: # truncate left inputs_ids = inputs_ids[:, -model_max_length:] if example["untruncated_inputs"].shape != inputs_ids.shape: print( "Input too long for model! ", "Left truncating under assumption that this is the prompt+output ", "to be fed to the *oracle* model", ) example.update({"untruncated_inputs": inputs_ids}) if (completion_length is not None) and (prompt_length is None): # leave at least one token as prefix # FIXME I think plus 1 since 0 is start tok slice_length = min(inputs_ids.shape[1] - 1, completion_length) elif (prompt_length is not None) and (completion_length is None): desired_comp_len = (inputs_ids.shape[1] - 1) - prompt_length slice_length = desired_comp_len if desired_comp_len > 0 else 0 else: raise ValueError( ( f"Can only tokenize and truncate based on either the desired prompt length or desired completion length,", f" but got completion_length:{completion_length},prompt_length:{prompt_length}", ) ) # truncate inputs_ids = inputs_ids[:, : inputs_ids.shape[1] - slice_length] # logic depending on special tokens for the model if "t5" in hf_model_name or "T0" in hf_model_name: inputs_ids[0, -1] = 1 # else: pass example.update({"input_ids": inputs_ids}) return example def tokenize_only( example: dict, input_col_name: str = "text", ref_output_col_name: str = None, tokenize_ref_output: bool = False, hf_model_name: str = None, tokenizer=None, model_max_length=None, ): """take hf dataset entry and preprocess it for completion by a model (but don't truncate) where the dataset optionally has a secondary column that is the reference output to be scored against""" """take hf dataset entry and preprocess it for completion by a model""" assert hf_model_name is not None, "need model name to know whether to adjust wrt special tokens" assert input_col_name in example, f"expects {input_col_name} field to be present" if ref_output_col_name is not None: assert ref_output_col_name in example, f"expects {ref_output_col_name} field to be present" # tokenize input input_ids = tokenizer( example[input_col_name], return_tensors="pt", truncation=True, max_length=model_max_length )["input_ids"] example.update({"input_ids": input_ids}) if tokenize_ref_output: # NOTE not sure this logic is useful/required if ref_output_col_name is not None: # tokenize ref output ref_output_ids = tokenizer( example[ref_output_col_name], return_tensors="pt", truncation=True, max_length=model_max_length, )["input_ids"] tokd_input_len, tokd_ref_output_length = input_ids.shape[1], ref_output_ids.shape[1] if tokd_input_len + tokd_ref_output_length > model_max_length: # truncate the ref output original_ref_output_len = tokd_ref_output_length ref_output_ids = ref_output_ids[:, : model_max_length - tokd_input_len] if original_ref_output_len != ref_output_ids.shape[1]: print( "Right truncating output, input+ref output too long for model. " "Note, since this is generation time truncating the reference doesn't affect anything really." ) example.update({"ref_output_ids": ref_output_ids}) # logic depending on special tokens for the model if "t5" in hf_model_name or "T0" in hf_model_name: raise NotImplementedError("T5 style model not yet supported") return example def tokenize_for_generation( example: dict, max_new_tokens: int = None, min_prompt_tokens: int = None, hf_model_name: str = None, tokenizer: Tokenizer = None, args: dict = None, ): # preprocessing, generation & scoring assert isinstance(example, dict), "Expect no batch dimension currently!" if not args.truncate_input_for_prompt: tokenize_ref_output = True # NOTE, note really sure how necessary this is # preprocess for model generation/completion example = tokenize_only( example, input_col_name=args.input_col_name, ref_output_col_name=args.ref_output_col_name, hf_model_name=hf_model_name, tokenizer=tokenizer, model_max_length=args.model_max_length, tokenize_ref_output=tokenize_ref_output, ) # Parse the results of tokenization. Simple, since # the prompt and baseline completion are from the raw text re_decoded_input = example[args.input_col_name] decoded_baseline_completion = example[args.ref_output_col_name] prompt_len = example["input_ids"].shape[1] baseline_completion_len = example["ref_output_ids"].shape[1] full_sample_len = prompt_len + baseline_completion_len # for now, remove this here, since it's not used downstream example.pop("ref_output_ids") else: # preprocess for model generation/completion example = tokenize_and_truncate( example, completion_length=max_new_tokens, prompt_length=min_prompt_tokens, hf_model_name=hf_model_name, tokenizer=tokenizer, ) # Logic to parse the results of tokenzation and splitting to # construct string versions of the prompt and baseline completion inputs = example["input_ids"] prompt_len = inputs.shape[1] # for isolating the "gold" baseline completion untruncated_inputs = example.pop("untruncated_inputs") full_sample_len = untruncated_inputs.shape[1] # decode the preprocessed input to store for audit re_decoded_input = tokenizer.batch_decode(inputs, skip_special_tokens=True)[0] # also decode the original suffix of the input for audit as the baseline baseline_completion_tokens = untruncated_inputs[:, inputs.shape[-1] :] decoded_baseline_completion = tokenizer.batch_decode( baseline_completion_tokens, skip_special_tokens=True )[0] baseline_completion_len = full_sample_len - prompt_len example.update( { "truncated_input": re_decoded_input, "baseline_completion": decoded_baseline_completion, "orig_sample_length": full_sample_len, "prompt_length": prompt_len, "baseline_completion_length": baseline_completion_len, } ) return example def collate_batch(input_ids: list, collator: DataCollatorWithPadding = None): """collate batch of input_ids into a padded batch of tensors""" assert ( input_ids[0].shape[0] == 1 and input_ids[0].shape[1] > 0 ), "expecting batch dimension of each tensor to be 1" # remove batch dimension for each tensor input_ids = [x.squeeze(0) for x in input_ids] return collator({"input_ids": input_ids})["input_ids"] def generate( examples, data_collator=None, generate_without_watermark=None, generate_with_watermark=None, watermark_processor=None, tokenizer=None, device=None, args=None, ): input_ids = collate_batch(input_ids=examples["input_ids"], collator=data_collator).to(device) with torch.no_grad(): if args.generation_seed is not None: torch.manual_seed(args.generation_seed) output_without_watermark = generate_without_watermark(input_ids=input_ids) if args.generation_seed is not None: torch.manual_seed(args.generation_seed) output_with_watermark = generate_with_watermark(input_ids=input_ids) if args.is_decoder_only_model: # need to isolate the newly generated tokens output_without_watermark = output_without_watermark[:, input_ids.shape[-1] :] output_with_watermark = output_with_watermark[:, input_ids.shape[-1] :] decoded_output_without_watermark = tokenizer.batch_decode( output_without_watermark, skip_special_tokens=True ) decoded_output_with_watermark = tokenizer.batch_decode( output_with_watermark, skip_special_tokens=True ) examples.update( { "no_wm_output": decoded_output_without_watermark, "w_wm_output": decoded_output_with_watermark, "no_wm_output_length": (output_without_watermark != tokenizer.pad_token_id) .sum(dim=-1) .tolist(), "w_wm_output_length": (output_with_watermark != tokenizer.pad_token_id) .sum(dim=-1) .tolist(), } ) if watermark_processor.spike_entropies is not None: examples["spike_entropies"] = watermark_processor._get_and_clear_stored_spike_ents() examples["spike_entropies"] = [ ents[:num_toks] for ents, num_toks in zip(examples["spike_entropies"], examples["w_wm_output_length"]) ] return examples