"""Run the DynamicRestartHCSC hill-climbing compressor over a JSONL dataset.

Search states returned by the searcher are appended to ``--output`` with
``utils.write_jsonl``. If the output file already exists, one input example is
skipped per record it contains, so an interrupted run can be resumed.
"""
import argparse
from scrl.hill_climbing import DynamicRestartHCSC, PunktTokenizer, WhiteSpaceTokenizer
from scrl.config_hc import load_config
from scrl.rewards import load_rewards
from scrl import utils
import tqdm
from pathlib import Path


def _check_target_args(target_len, target_ratio):
    """Raise ValueError unless exactly one of target_len / target_ratio is set."""
    if target_len is not None and target_ratio is not None:
        raise ValueError("specify only one of target_len and target_ratio")
    if target_len is None and target_ratio is None:
        raise ValueError("one of target_len or target_ratio is required")


def run_on_dataset(
        searcher,
        dataset,
        target_len,
        target_ratio,
        n_steps,
        outpath,
    ):
    """Compress every example in ``dataset`` and append the states to ``outpath``.

    Args:
        searcher: search object called as
            ``searcher(sources, target_lens=..., n_steps=...)``; it must also
            expose a ``tokenizer`` attribute used to measure input lengths.
        dataset: iterable of dicts, each with a ``"text"`` field.
        target_len: fixed target length used for every example, or None.
        target_ratio: target length as a fraction of each input's token count,
            or None. Exactly one of target_len / target_ratio must be given.
        n_steps: number of search steps forwarded to the searcher.
        outpath: JSONL file the returned states are appended to.

    Raises:
        ValueError: if both or neither of target_len and target_ratio are set.
    """
    # Fail fast instead of hitting a TypeError (None * int) mid-run.
    _check_target_args(target_len, target_ratio)
    outpath = Path(outpath)

    # Resume support: each record already in the output stands for one
    # processed input example, so that many examples are skipped.
    start = sum(1 for _ in utils.read_jsonl(outpath)) if outpath.exists() else 0

    passed = 0
    batches = utils.batchify(dataset, batch_size=4)
    for batch in tqdm.tqdm(batches):
        batch_start = passed
        passed += len(batch)
        if passed <= start:
            continue  # the whole batch is already in the output
        if batch_start < start:
            # The existing output ends partway through this batch: process
            # only the remainder so no example is written twice.
            # NOTE(review): assumes batchify yields sliceable sequences
            # (batches are already len()-ed above) — confirm.
            batch = batch[start - batch_start:]
        if batch_start <= start:
            print(f"starting at position {start}")

        sources = [x["text"] for x in batch]
        if target_len is not None:
            target_lens = [target_len] * len(sources)
        else:
            # presumably tokenizer returns one token list per source — verify
            input_lens = [len(tokens) for tokens in searcher.tokenizer(sources)]
            target_lens = [round(target_ratio * n) for n in input_lens]
            print(input_lens)
            print(target_lens)

        states = searcher(
            sources,
            target_lens=target_lens,
            n_steps=n_steps,
        )
        utils.write_jsonl(states, outpath, "a")


def main(args):
    """Build the reward objective and searcher, then run over ``args.dataset``.

    Raises:
        ValueError: if both or neither of target_len / target_ratio are set.
    """
    # Validate before loading rewards/models, which may be slow.
    _check_target_args(args.target_len, args.target_ratio)
    config = load_config(args)
    print("DEVICE:", config.device)
    objective = load_rewards(config)
    tokenizer = WhiteSpaceTokenizer() if args.pretokenized else PunktTokenizer()
    searcher = DynamicRestartHCSC(tokenizer, objective)
    dataset = list(utils.read_jsonl(args.dataset))
    run_on_dataset(
        searcher,
        dataset,
        args.target_len,
        args.target_ratio,
        args.steps,
        args.output,
    )


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", help="path to JSON config file", required=True)
    parser.add_argument("--output", required=True)
    parser.add_argument("--dataset", required=True)
    parser.add_argument("--pretokenized", action="store_true")
    parser.add_argument("--device", default="cuda")
    # Exactly one way of choosing the summary length must be given.
    target = parser.add_mutually_exclusive_group(required=True)
    target.add_argument("--target-len", type=int, default=None)
    target.add_argument("--target-ratio", type=float, default=None)
    parser.add_argument("--steps", default=1000, type=int)
    # NOTE(review): load_config is applied here and again inside main(); this
    # only works if its result still carries the CLI attributes (config path,
    # dataset, output, ...). Confirm whether one of the calls is redundant.
    main(load_config(parser.parse_args()))