# Basic imports
import sys
import os
import argparse
from typing import List, Iterable, Optional
from functools import partial

import time
from tqdm import tqdm
import random
import math
from statistics import mean

import numpy as np
import torch
from torch import Tensor
from tokenizers import Tokenizer

import wandb
import matplotlib.pyplot as plt

# set the cache path before the HF imports,
# since it's unclear exactly when the library reads this variable
# TODO change to passing as an arg to the model load fn
USER = "jkirchen"
# Huggingface cache
HF_HOME = f"/cmlscratch/{USER}/.cache/huggingface"
# HF_HOME = f"/scratch0/{USER}/.cache/huggingface"
# HF_HOME = f"/scratch1/{USER}/.cache/huggingface"
os.environ["HF_HOME"] = HF_HOME
print(os.environ["HF_HOME"])

# HF classes
from transformers import (AutoTokenizer,
                          AutoModelForSeq2SeqLM,
                          AutoModelForCausalLM,
                          LogitsProcessorList)
from datasets import load_dataset, Dataset

# watermarking micro lib
from watermark import (BlacklistLogitsProcessor,
                       add_idx,
                       check_input_lengths,
                       check_output_lengths,
                       tokenize_for_generation,
                       generate_completions,
                       evaluate_generation_fluency)

# better bool flag type for argparse
from submitit_utils import str2bool

# some file i/o helpers
from io_utils import write_jsonlines, write_json, read_jsonlines, read_json
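
# This script runs the blacklist-watermarked generation pipeline end to end:
# it loads a HF model and a streaming dataset, builds prompts, generates
# completions with and without the blacklist logits processor, and then
# scores the fluency (loss/PPL) of the outputs with a separate oracle model,
# optionally logging everything to wandb and dumping jsonl artifacts.
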
def main(args):

    ###########################################################################
    # Start logging
    ###########################################################################
    if not args.no_wandb:
        # storing slurm info to be sent to wandb to allow auditing logfiles later
        args.SLURM_JOB_ID = os.getenv("SLURM_JOB_ID")
        args.SLURM_ARRAY_JOB_ID = os.getenv("SLURM_ARRAY_JOB_ID")
        args.SLURM_ARRAY_TASK_ID = os.getenv("SLURM_ARRAY_TASK_ID")

        # start a new wandb run to track this experiment, will send data to it later
        run = wandb.init(
            # set the wandb project where this run will be logged
            project=args.wandb_project,
            entity=args.wandb_entity,
            name=args.run_name,
            # track hyperparameters and run metadata
            config=args
        )

    print(f"Output dir for this run: {args.output_dir}")
    # notify if exists
    if os.path.exists(args.output_dir):
        print(f"Output dir for this run already exists!")
        print(f"Contents: {sorted(os.listdir(args.output_dir))}")
    else:
        # create the output dir where run artifacts are stored
        os.makedirs(args.output_dir)

    ###########################################################################
    # Instantiate model and tokenizer
    ###########################################################################
    hf_model_name = args.model_name
    if "t5" in hf_model_name or "T0" in hf_model_name:
        model = AutoModelForSeq2SeqLM.from_pretrained(hf_model_name)
    else:
        model = AutoModelForCausalLM.from_pretrained(hf_model_name)

    tokenizer = AutoTokenizer.from_pretrained(hf_model_name)

    # defaults to device 0
    # will need to use 'parallelize' for multi-gpu sharding
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    model.eval()

    ###########################################################################
    # Load the dataset
    ###########################################################################
    dataset_name, dataset_config_name = args.dataset_name, args.dataset_config_name
    if dataset_name == "cml_pile":
        subsets = [dataset_config_name]
        dataset = load_dataset("input/cml_pile.py",
                               subsets=subsets,
                               streaming=True,
                               split=None,
                               ignore_verifications=True)["train"]
    else:
        dataset = load_dataset(dataset_name, dataset_config_name, split="train", streaming=True)

    # log an example
    ds_iterator = iter(dataset)
    idx = 75  # if this is c4, it's the schumacher example lol
    i = 0
    while i < idx:
        next(ds_iterator)
        i += 1
    example = next(ds_iterator)
    print(example)

    ###########################################################################
    # Construct the blacklist processor/sampler
    ###########################################################################
    all_token_ids = list(tokenizer.get_vocab().values())
    vocab_size = len(all_token_ids)
    print(f"Vocabulary size: {vocab_size}")

    max_new_tokens = args.max_new_tokens
    min_prompt_tokens = args.min_prompt_tokens

    init_seed = args.initial_seed
    dyna_seed = args.dynamic_seed  # type not value
    bl_proportion = args.bl_proportion
    bl_logit_bias = args.bl_logit_bias
    bl_type = args.bl_type
    n_beams = args.num_beams
    early_stopping = args.early_stopping
    no_repeat_ngram_size = args.no_repeat_ngram_size

    store_bl_ids = args.store_bl_ids
    store_spike_ents = args.store_spike_ents

    bl_processor = BlacklistLogitsProcessor(bad_words_ids=None,
                                            store_bl_ids=store_bl_ids,
                                            store_spike_ents=store_spike_ents,
                                            eos_token_id=tokenizer.eos_token_id,
                                            vocab=all_token_ids,
                                            vocab_size=vocab_size,
                                            bl_proportion=bl_proportion,
                                            bl_logit_bias=bl_logit_bias,
                                            bl_type=bl_type,
                                            initial_seed=init_seed,
                                            dynamic_seed=dyna_seed)

    logit_processor_lst = LogitsProcessorList([bl_processor])
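
    # A sketch of what the processor does (see watermark.py for the real logic):
    # at each decoding step it pseudorandomly splits the vocabulary into a
    # blacklist/whitelist partition governed by bl_proportion, seeded according
    # to dynamic_seed (e.g. "markov_1" reseeds from the previous token), and
    # then either forbids blacklisted tokens ("hard") or adds bl_logit_bias to
    # the whitelist logits ("soft") before the next token is chosen.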

    # Greedy and basic beam search, default
    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        num_beams=n_beams,
    )

    if n_beams > 1:
        # these are only for beam search repetition correction
        if no_repeat_ngram_size > 0:
            gen_kwargs.update(dict(no_repeat_ngram_size=no_repeat_ngram_size))
        gen_kwargs.update(dict(early_stopping=early_stopping))

    if args.use_sampling:
        gen_kwargs.update(dict(do_sample=True,
                               top_k=0,
                               temperature=args.sampling_temp))

    if args.all_gas_no_eos:
        gen_kwargs.update(dict(suppress_tokens=[tokenizer.eos_token_id]))

    generate_without_blacklist = partial(
        model.generate,
        **gen_kwargs
    )
    generate_with_blacklist = partial(
        model.generate,
        logits_processor=logit_processor_lst,
        **gen_kwargs
    )
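
    # Note: both partials share the exact same decoding kwargs; the only
    # difference is the watermarking logits processor, so the "no_bl" and
    # "w_bl" completions for a given prompt are directly comparable.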

    ###########################################################################
    # Construct the generation and measurement pipeline (lazy)
    # that pulls from the streaming dataset, applies the generation map funcs
    ###########################################################################

    # Set up the pipeline functions
    if "c4" in dataset_name:
        columns_to_remove = ["text", "timestamp", "url"]
    else:
        columns_to_remove = []

    # Construct the data filtering/sampling scheme partials
    token_kwargs = dict(
        hf_model_name=hf_model_name,
        tokenizer=tokenizer,
        model=model,
    )
    if args.input_truncation_strategy == "prompt_length":
        token_kwargs.update(dict(min_prompt_tokens=min_prompt_tokens))
    elif args.input_truncation_strategy == "completion_length":
        token_kwargs.update(dict(max_new_tokens=max_new_tokens))
    else:
        raise ValueError(f"Unknown input truncation strategy {args.input_truncation_strategy}")
    tokenize_prompts = partial(
        tokenize_for_generation,
        **token_kwargs
    )

    input_check_kwargs = dict(
        # min_sample_len=min_prompt_tokens + max_new_tokens,
        min_sample_len=args.min_sample_tokens,  # first line is a bug sometimes with large amounts
    )
    if args.input_filtering_strategy == "prompt_length":
        input_check_kwargs.update(dict(min_prompt_len=min_prompt_tokens,
                                       min_completion_len=0))
    elif args.input_filtering_strategy == "completion_length":
        input_check_kwargs.update(dict(min_prompt_len=0,
                                       min_completion_len=max_new_tokens))
    elif args.input_filtering_strategy == "prompt_and_completion_length":
        input_check_kwargs.update(dict(min_prompt_len=min_prompt_tokens,
                                       min_completion_len=max_new_tokens))
    else:
        raise ValueError(f"Unknown input filtering strategy {args.input_filtering_strategy}")
    input_check = partial(
        check_input_lengths,
        **input_check_kwargs
    )

    if args.output_filtering_strategy == "max_new_tokens":
        output_kwargs = dict(min_output_len=max_new_tokens)
    elif args.output_filtering_strategy == "no_filter":
        output_kwargs = dict(min_output_len=0)
    else:
        raise ValueError(f"Unknown output filtering strategy {args.output_filtering_strategy}")
    output_check = partial(
        check_output_lengths,
        **output_kwargs
    )

    gen_completions = partial(
        generate_completions,
        max_new_tokens=max_new_tokens,
        hf_model_name=hf_model_name,
        tokenizer=tokenizer,
        model=model,
        no_bl_partial=generate_without_blacklist,
        w_bl_partial=generate_with_blacklist,
        bl_processor_list=logit_processor_lst,
    )
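
    # gen_completions maps each filtered row to a record containing the real
    # ("baseline") continuation plus a "no_bl" (unwatermarked) and "w_bl"
    # (watermarked) model completion, along with the token counts and timing
    # fields that get printed in the main loop below.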

    ###########################################################################
    # Compose/apply the pipeline steps
    ###########################################################################

    # Apply the pipeline operations to the dataset
    indexed_dataset = dataset.map(add_idx, batched=False, with_indices=True)

    # shuffle the first shuffle_buffer_size rows of the (streaming) dataset
    if args.shuffle_dataset:
        shuffled_dataset = indexed_dataset.shuffle(seed=args.shuffle_seed,
                                                   buffer_size=args.shuffle_buffer_size)
    else:
        shuffled_dataset = indexed_dataset

    # tokenize and truncate the row inputs to create prompts according to the strategy spec'd above
    tokenized_and_truncated_dataset = shuffled_dataset.map(tokenize_prompts,
                                                           batched=False,
                                                           with_indices=True)

    # filter the rows of the dataset based on length checks for the tokenized prompts and baseline completions
    input_length_filtered_dataset = tokenized_and_truncated_dataset.filter(input_check,
                                                                           batched=False,
                                                                           with_indices=True)

    # perform generation by calling the models
    columns_to_remove += ["inputs", "untruncated_inputs"]  # these are now materialized and must be dropped externally
    generations_dataset = input_length_filtered_dataset.map(gen_completions,
                                                            batched=False,
                                                            with_indices=True,
                                                            remove_columns=columns_to_remove)
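
    # Because the source dataset is streaming, all of the map/filter calls
    # above are lazy; no tokenization or model calls happen until the
    # iterator over generations_dataset is consumed in the loop below.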

    # # filter the dataset a last time based on the lengths of the outputs of the model
    # output_length_filtered_dataset = generations_dataset.filter(output_check,
    #                                                             batched=False,
    #                                                             with_indices=True)

    ###########################################################################
    # Main loop - actually executes the generation pipeline.
    # and accumulates the result rows in a list, assumes list is "small"-ish
    # and we aren't accumulating any tensors or other memory hogging artifacts
    ###########################################################################
    if not args.load_prev_generations:

        processed_examples = []
        ds_iterator = iter(generations_dataset)
        i = 0
        while i < args.limit_indices:

            ex = next(ds_iterator)

            # log basics to stdout
            print(f"#"*80)
            print(f"dataset index: {ex['idx']}")
            print(f"orig_sample_length: {ex['orig_sample_length']}")
            print(f"prompt_length: {ex['prompt_length']}")
            print(f"real_completion_length: {ex['real_completion_length']}")
            print(f"no_bl_num_tokens_generated: {ex['no_bl_num_tokens_generated']}")
            print(f"w_bl_num_tokens_generated: {ex['w_bl_num_tokens_generated']}")

            print(f"\ntruncated_input: ")
            print(ex["truncated_input"])
            print(f"\nbaseline_completion: ")
            print(ex["baseline_completion"])
            print(f"\nno_bl_output: ")
            print(ex["no_bl_output"])
            print(f"\nw_bl_output: ")
            print(ex["w_bl_output"])
            print(f"\nno_bl_gen_time: ")
            print(ex["no_bl_gen_time"])
            print(f"\nno_bl_sec_per_tok: ")
            print(ex["no_bl_sec_per_tok"])
            print(f"\nno_bl_tok_per_sec: ")
            print(ex["no_bl_tok_per_sec"])
            print(f"\nw_bl_gen_time: ")
            print(ex["w_bl_gen_time"])
            print(f"\nw_bl_sec_per_tok: ")
            print(ex["w_bl_sec_per_tok"])
            print(f"\nw_bl_tok_per_sec: ")
            print(ex["w_bl_tok_per_sec"])

            processed_examples.append(ex)
            if output_check(ex) == True:
                i += 1
            else:
                print(f"\nGeneration too short, saving outputs, but not incrementing counter...\n",
                      f"{i} of {len(processed_examples)} rows were satisfactory so far",
                      f"current generation overhead ratio: {round(len(processed_examples)/(i+1), 3)}",
                      f"completed {round(i/args.limit_indices, 2)} of total")

        print(f"#"*80,
              f"\nGeneration output length check overhead was num rows processed={len(processed_examples)}",
              f"for {args.limit_indices} samples. Ratio: {round(len(processed_examples)/args.limit_indices, 3)}")

    ###########################################################################
    # Generation jsonl dumping/loading
    ###########################################################################
    gen_table_meta_path = f"{args.output_dir}/gen_table_meta.json"
    gen_table_path = f"{args.output_dir}/gen_table.jsonl"
    safe_gen_table_path = f"{args.output_dir}/gen_table_safe.jsonl"

    args.gen_table_already_existed = False

    if not args.load_prev_generations:

        if os.path.exists(gen_table_path):
            print(f"Found existing generation files at this output dir: {args.output_dir}")
            print(f"Writing generations at alternate, safe path and exiting. Note! this only works once. "
                  f"Safe version will get overwritten next time ... ")
            gen_table_path = f"{args.output_dir}/gen_table_safe.jsonl"
            args.gen_table_already_existed = True

        gen_table_meta = args.__dict__
        gen_table = processed_examples

        write_jsonlines(gen_table, gen_table_path)
        write_json(gen_table_meta, gen_table_meta_path, indent=4)

        if args.gen_table_already_existed:
            # finish the wandb run
            if not args.no_wandb: run.finish()
            return  # from main, for safety
    else:
        print(f"Loading previously generated outputs for evaluation via oracle model and metrics...")
        assert os.path.exists(gen_table_meta_path), f"failed file check for prev generations metadata json file: {gen_table_meta_path}"
        assert os.path.exists(gen_table_path), f"failed file check for prev generations jsonl file: {gen_table_path}"

        curr_gen_table_meta = args.__dict__.copy()
        prev_gen_table_meta = read_json(gen_table_meta_path)

        assert not prev_gen_table_meta["gen_table_already_existed"], "failed safety check because 'gen_table_already_existed' was true in the metadata file in this dir, indicating a possible issue"
        assert not os.path.exists(safe_gen_table_path), "failed safety check because there is a secondary 'safe' marked file in this dir, indicating a possible issue"

        params_to_ignore = ["load_prev_generations", "SLURM_JOB_ID", "SLURM_ARRAY_JOB_ID", "SLURM_ARRAY_TASK_ID"]
        for k in params_to_ignore:
            del curr_gen_table_meta[k]
            del prev_gen_table_meta[k]
        assert curr_gen_table_meta == prev_gen_table_meta, "failed safety check that current script params equal the params for the prev generations being loaded"

        # gen_table_meta = argparse.Namespace(**args.__dict__)
        gen_table_meta = args
        gen_table = [ex for ex in read_jsonlines(gen_table_path)]

    if args.generate_only:
        # finish the wandb run
        if not args.no_wandb: run.finish()
        return  # early exit, will reload later for ppl scoring

    # Create a new dataset object either from the loop over examples
    # or from the reloaded json lines
    # gen_table_ds = Dataset.from_generator(ex for ex in gen_table)  # hack since from_list is newer, and had 2.4.0
    gen_table_ds = Dataset.from_list(gen_table)
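
    # Materializing the rows into an in-memory Dataset lets us reuse the same
    # .map machinery (now non-streaming) for the metrics pass below.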

    ###########################################################################
    # Perplexity (PPL) evaluation
    # which is a separate step, partly because it requires a different model on the gpu
    ###########################################################################

    # Load the oracle model for PPL measurement
    # Assume a single GPU, so we need to free the original model's memory for the oracle model
    if model is not None:
        model = model.to(torch.device("cpu"))
        del model

    oracle_model_name = args.oracle_model_name
    print(f"Loading oracle model: {oracle_model_name}")

    oracle_tokenizer = AutoTokenizer.from_pretrained(oracle_model_name)
    oracle_model = AutoModelForCausalLM.from_pretrained(oracle_model_name).to(device)
    oracle_model.eval()

    # construct fluency/ppl partial
    eval_gen_metrics = partial(
        evaluate_generation_fluency,
        oracle_model_name=oracle_model_name,
        oracle_model=oracle_model,
        oracle_tokenizer=oracle_tokenizer
    )
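
    # For each row the oracle model scores the real "baseline" completion, the
    # unwatermarked "no_bl" output, and the watermarked "w_bl" output,
    # producing the per-row loss and PPL columns averaged below.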

    print(f"Computing metrics on model generations: {gen_table_ds}")

    gen_table_w_metrics_ds = gen_table_ds.map(eval_gen_metrics, batched=False, with_indices=True)

    print(f"#"*80)
    print(f"baseline avg PPL: {mean(gen_table_w_metrics_ds['baseline_ppl'])}")
    print(f"baseline avg loss: {mean(gen_table_w_metrics_ds['baseline_loss'])}")
    print(f"no_bl avg PPL: {mean(gen_table_w_metrics_ds['no_bl_ppl'])}")
    print(f"no_bl avg loss: {mean(gen_table_w_metrics_ds['no_bl_loss'])}")
    print(f"w_bl avg PPL: {mean(gen_table_w_metrics_ds['w_bl_ppl'])}")
    print(f"w_bl avg loss: {mean(gen_table_w_metrics_ds['w_bl_loss'])}")

    # clear the oracle model off the gpu
    oracle_model = oracle_model.to(torch.device("cpu"))
    del oracle_model

    gen_table_w_metrics_path = f"{args.output_dir}/gen_table_w_metrics.jsonl"
    if os.path.exists(gen_table_w_metrics_path):
        print(f"Found existing generation files with metrics added at this output dir. Overwriting anyway :\\ -> {args.output_dir}")

    gen_table_w_metrics_lst = [ex for ex in gen_table_w_metrics_ds]
    write_jsonlines(gen_table_w_metrics_lst, gen_table_w_metrics_path)

    # finish the wandb run
    if not args.no_wandb: run.finish()

    return

if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="Run watermarked huggingface LM generation pipeline")
    parser.add_argument(
        "--model_name",
        type=str,
        default="facebook/opt-2.7b",
        help="Main model, path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="c4",
        help="The name of the dataset to use (via the datasets library).",
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default="realnewslike",
        help="The configuration name of the dataset to use (via the datasets library).",
    )
    parser.add_argument(
        "--shuffle_dataset",
        type=str2bool,
        default=False,
        help="Whether to shuffle the dataset before sampling.",
    )
    parser.add_argument(
        "--shuffle_seed",
        type=int,
        default=1234,
        help="The seed to use for dataset shuffle op.",
    )
    parser.add_argument(
        "--shuffle_buffer_size",
        type=int,
        default=10_000,
        help="The buffer size to use for dataset shuffle op - takes n rows first, then shuffles those indices",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=100,
        help="The number of tokens to generate using the model, and the num tokens removed from real text sample",
    )
    parser.add_argument(
        "--min_prompt_tokens",
        type=int,
        default=50,  # 500
        help="The minimum number of tokens to use as the prompt (for the prompt_length truncation and filtering strategies).",
    )
    parser.add_argument(
        "--min_sample_tokens",
        type=int,
        default=0,
        help="The minimum length of raw prompt samples to consider.",
    )
    parser.add_argument(
        "--limit_indices",
        type=int,
        default=5,  # 500
        help="The number of examples (first N) to process from the dataset.",
    )
    parser.add_argument(
        "--input_truncation_strategy",
        type=str,
        default="completion_length",
        choices=["completion_length", "prompt_length"],
        help="The strategy to use when tokenizing and truncating raw inputs to make prompts.",
    )
    parser.add_argument(
        "--input_filtering_strategy",
        type=str,
        default="completion_length",
        choices=["completion_length", "prompt_length", "prompt_and_completion_length"],
        help="The strategy to use when filtering rows based on the lengths of the tokenized prompts and/or baseline completions.",
    )
    parser.add_argument(
        "--output_filtering_strategy",
        type=str,
        default="no_filter",
        choices=["no_filter", "max_new_tokens"],
        help=("The strategy to use when filtering/skipping rows if the model didn't "
              "generate enough tokens to facilitate analysis.")
    )
    parser.add_argument(
        "--initial_seed",
        type=int,
        default=1234,
        help=("The initial seed to use in the blacklist randomization process. "
              "Is unused if the process is markov generally. Can be None."),
    )
    parser.add_argument(
        "--dynamic_seed",
        type=str,
        default="markov_1",
        choices=[None, "initial", "markov_1"],
        help="The seeding procedure to use when sampling the blacklist at each step.",
    )
    parser.add_argument(
        "--bl_proportion",
        type=float,
        default=0.5,
        help="The ratio of blacklist to whitelist tokens when splitting the vocabulary",
    )
    parser.add_argument(
        "--bl_logit_bias",
        type=float,
        default=1.0,
        help="The amount of bias (absolute) to add to the logits in the whitelist half of the vocabulary at every step",
    )
    parser.add_argument(
        "--bl_type",
        type=str,
        default="soft",
        choices=["soft", "hard"],
        help="The type of blacklisting being performed.",
    )
    parser.add_argument(
        "--num_beams",
        type=int,
        default=1,
        help="The number of beams to use where '1' is no beam search.",
    )
    parser.add_argument(
        "--no_repeat_ngram_size",
        type=int,
        default=0,
        # default=8,
        help="The ngram size the model is prevented from repeating; can't be too small or the model is handicapped, too large and it blows up in complexity.",
    )
    parser.add_argument(
        "--early_stopping",
        type=str2bool,
        default=False,
        help="Whether to use early stopping, only for beam search.",
    )
    # parser.add_argument(
    #     "--hard_min_length",
    #     type=str2bool,
    #     default=False,
    #     help="Whether to use the min length logits processor to force the generations to be max_new_tokens.",
    # )
    parser.add_argument(
        "--oracle_model_name",
        type=str,
        default="EleutherAI/gpt-j-6B",
        help="PPL scoring, or oracle model, path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--no_wandb",
        type=str2bool,
        default=False,
        help="Whether to disable logging to wandb.",
    )
    parser.add_argument(
        "--wandb_project",
        type=str,
        default="lm-blacklisting",
        help="The name of the wandb project.",
    )
    parser.add_argument(
        "--wandb_entity",
        type=str,
        default="jwkirchenbauer",
        help="The wandb entity/user for the project.",
    )
    parser.add_argument(
        "--run_name",
        type=str,
        default=None,
        help="The unique name for the run.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./output",
        help="The output directory where this run's artifacts will be stored.",
    )
    parser.add_argument(
        "--load_prev_generations",
        type=str2bool,
        default=False,
        help=("Whether to run generations or load them from a json lines file in the output_dir. "
              "If True, this file must exist and its meta/args must match."),
    )
    parser.add_argument(
        "--store_bl_ids",
        type=str2bool,
        default=False,
        help=("Whether to store all the blacklists while generating with the bl processor."),
    )
    parser.add_argument(
        "--store_spike_ents",
        type=str2bool,
        default=False,
        help=("Whether to store the spike entropies while generating with the bl processor."),
    )
    parser.add_argument(
        "--use_sampling",
        type=str2bool,
        default=False,
        help=("Whether to perform sampling during generation (non-greedy decoding)."),
    )
    parser.add_argument(
        "--sampling_temp",
        type=float,
        default=0.7,
        help="The temperature to use when generating using multinomial sampling.",
    )
    parser.add_argument(
        "--generate_only",
        type=str2bool,
        default=False,
        help=("Whether to only produce outputs and not evaluate anything like ppl"),
    )
    parser.add_argument(
        "--all_gas_no_eos",
        type=str2bool,
        default=False,
        help=("Whether to weight the EOS token as -inf"),
    )
    args = parser.parse_args()

    main(args)
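
# Example invocation (the script filename and all values below are illustrative,
# adjust them to your own setup):
#   python run_watermarking.py \
#       --model_name facebook/opt-1.3b \
#       --dataset_name c4 --dataset_config_name realnewslike \
#       --max_new_tokens 200 --limit_indices 500 \
#       --bl_proportion 0.5 --bl_logit_bias 2.0 --bl_type soft \
#       --use_sampling True --sampling_temp 0.7 \
#       --oracle_model_name facebook/opt-2.7b \
#       --output_dir ./output/example_run --no_wandb True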