Spaces:
Runtime error
Runtime error
import os | |
import argparse | |
from lats import run_lats | |
def get_args() -> argparse.Namespace:
    """Parse the command-line options for a run.

    Reads the process arguments through ``argparse``. Options declared
    without a default (``--run_name``, ``--strategy``, ``--language``,
    ``--model``) are ``None`` when omitted; ``--verbose`` is ``False``
    unless the flag is given.

    Returns:
        The populated namespace, one attribute per option defined below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--run_name", type=str, help="The name of the run")
    parser.add_argument("--root_dir", type=str,
                        help="The root logging directory", default="root")
    # NOTE(review): the default "root" mirrors --root_dir and may be a
    # copy-paste leftover -- confirm the intended default dataset path.
    parser.add_argument("--dataset_path", type=str,
                        help="The path to the benchmark dataset", default="root")
    parser.add_argument("--strategy", type=str,
                        help="Strategy: `simple`, `reflexion`")
    # Help text previously read "Strategy: ..." (copied from the option above).
    parser.add_argument("--language", type=str, help="Language: `py` or `rs`")
    parser.add_argument(
        "--model", type=str, help="OpenAI models only for now. For best results, use GPT-4")
    parser.add_argument("--pass_at_k", type=int,
                        help="Pass@k metric", default=1)
    parser.add_argument("--max_iters", type=int,
                        help="The maximum number of self-improvement iterations", default=10)
    parser.add_argument("--expansion_factor", type=int,
                        help="The expansion factor for the reflexion UCS and A* strategy", default=3)
    parser.add_argument("--verbose", action='store_true',
                        help="To print live logs")
    parser.add_argument("--instruction", type=str,
                        help="text string", default="")
    parser.add_argument("--n_samples", type=int,
                        help="The number of nodes added during expansion", default=3)
    parser.add_argument("--depth", type=int,
                        help="Tree depth", default=5)
    # TODO: implement this
    # parser.add_argument("--is_resume", action='store_true', help="To resume run")
    # parser.add_argument("--resume_dir", type=str, help="If resume, the logging directory", default="")
    return parser.parse_args()
def strategy_factory(strategy: str):
    """Return the callable that executes a run for ``strategy``.

    ``strategy`` is currently not consulted: whatever value is passed, the
    result is a keyword-only wrapper around ``run_lats``.

    Args:
        strategy: Requested strategy name (presently unused).

    Returns:
        A function that accepts keyword arguments only and forwards them
        to ``run_lats``, returning whatever ``run_lats`` returns.
    """
    # An immutable tuple default avoids sharing one mutable list object
    # across every call of the helper.
    def kwargs_wrapper_gen(func, delete_keys=()):
        """Wrap ``func`` so the keys in ``delete_keys`` are removed first."""
        def kwargs_wrapper(**kwargs):
            # Every listed key must be present; a missing one raises KeyError.
            for key in delete_keys:
                del kwargs[key]
            return func(**kwargs)
        return kwargs_wrapper

    return kwargs_wrapper_gen(run_lats, delete_keys=())
def lats_main(args):
    """Run the selected strategy with the options carried by ``args``.

    Looks up the runner via ``strategy_factory(args.strategy)`` and invokes
    it with the model, language, iteration limit, verbosity, instruction,
    sample count and depth read from ``args``.

    Args:
        args: Object exposing ``strategy``, ``model``, ``language``,
            ``max_iters``, ``verbose``, ``instruction``, ``n_samples`` and
            ``depth`` attributes (e.g. the namespace from ``get_args``).

    Returns:
        Whatever the runner returns.
    """
    runner = strategy_factory(args.strategy)

    # Collect the keyword arguments first so the call itself stays short.
    run_kwargs = {
        "model_name": args.model,
        "language": args.language,
        "max_iters": args.max_iters,
        "verbose": args.verbose,
        "instruction": args.instruction,
        "n_samples": args.n_samples,
        "depth": args.depth,
    }
    return runner(**run_kwargs)
def main(args):
    """Script-level entry point: delegate to ``lats_main``.

    The value returned by ``lats_main`` is not propagated; this function
    always returns ``None``.
    """
    # Result intentionally discarded.
    lats_main(args)
if __name__ == "__main__":
    # Executed only when run as a script: parse the command line, then
    # hand the parsed options to main().
    main(get_args())