Spaces:
Runtime error
Runtime error
import argparse | |
from scrl.hill_climbing import DynamicRestartHCSC, PunktTokenizer, WhiteSpaceTokenizer | |
from scrl.config_hc import load_config | |
from scrl.rewards import load_rewards | |
from scrl import utils | |
import tqdm | |
from pathlib import Path | |
def run_on_dataset(
    searcher,
    dataset,
    target_len,
    target_ratio,
    n_steps,
    outpath,
    batch_size=4,
):
    """Search every example in ``dataset`` and append the states to ``outpath``.

    The run is resumable. If ``outpath`` already exists, the number of JSONL
    records it holds is taken as the number of examples already processed.
    Only the examples after that offset are searched, so nothing is written
    twice even if the previous run stopped in the middle of a batch.

    Args:
        searcher: search object (``DynamicRestartHCSC`` in this script). It is
            called as ``searcher(sources, target_lens=..., n_steps=...)`` and
            must return one state dict per source. Its ``tokenizer`` attribute
            is used when ``target_ratio`` drives the target lengths.
        dataset: sequence of dicts, each with a ``"text"`` entry.
        target_len: fixed target length for every example, or ``None``.
        target_ratio: target length as a fraction of each input's token
            count; only used when ``target_len`` is ``None``.
        n_steps: number of search steps forwarded to ``searcher``.
        outpath: JSONL file that the returned states are appended to.
        batch_size: number of examples passed to ``searcher`` per call.

    Raises:
        ValueError: if both ``target_len`` and ``target_ratio`` are ``None``.
    """
    if target_len is None and target_ratio is None:
        raise ValueError("either target_len or target_ratio must be set")

    outpath = Path(outpath)
    # Records already written by a previous (possibly interrupted) run.
    start = sum(1 for _ in utils.read_jsonl(outpath)) if outpath.exists() else 0
    # Slice at the exact offset rather than skipping whole batches, so a
    # partially written batch is not processed (and appended) a second time.
    remaining = list(dataset)[start:]
    if not remaining:
        return
    print(f"starting at position {start}")

    for batch in tqdm.tqdm(utils.batchify(remaining, batch_size=batch_size)):
        sources = [x["text"] for x in batch]
        if target_len is not None:
            target_lens = [target_len] * len(sources)
        else:
            input_lens = [len(tokens) for tokens in searcher.tokenizer(sources)]
            target_lens = [round(target_ratio * n) for n in input_lens]
        states = searcher(
            sources,
            target_lens=target_lens,
            n_steps=n_steps,
        )
        # Append after every batch so finished work survives a crash.
        utils.write_jsonl(states, outpath, "a")
def main(args):
    """Build the hill-climbing searcher from ``args`` and run it on a dataset.

    Args:
        args: namespace from this script's argument parser. It is passed
            unchanged to ``load_config`` and must also provide
            ``pretokenized``, ``dataset``, ``target_len``, ``target_ratio``,
            ``steps`` and ``output``.

    Raises:
        ValueError: unless exactly one of ``target_len`` / ``target_ratio``
            is set. This replaces an ``assert``, which is stripped under
            ``python -O``.
    """
    # Validate before loading the config and reward objects so bad input
    # fails fast with a clear message.
    if (args.target_len is None) == (args.target_ratio is None):
        raise ValueError(
            "exactly one of --target-len or --target-ratio must be given"
        )
    config = load_config(args)
    print("DEVICE:", config.device)
    objective = load_rewards(config)
    # WhiteSpaceTokenizer for --pretokenized input, PunktTokenizer otherwise.
    tokenizer = WhiteSpaceTokenizer() if args.pretokenized else PunktTokenizer()
    searcher = DynamicRestartHCSC(tokenizer, objective)
    dataset = list(utils.read_jsonl(args.dataset))
    run_on_dataset(
        searcher,
        dataset,
        args.target_len,
        args.target_ratio,
        args.steps,
        args.output,
    )
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Run DynamicRestartHCSC search over a JSONL dataset."
    )
    parser.add_argument("--config", help="path to JSON config file", required=True)
    parser.add_argument(
        "--output",
        required=True,
        help="JSONL file results are appended to; existing records are skipped on resume",
    )
    parser.add_argument(
        "--dataset",
        required=True,
        help="JSONL file whose records carry a 'text' field",
    )
    parser.add_argument(
        "--pretokenized",
        action="store_true",
        help="use WhiteSpaceTokenizer instead of PunktTokenizer",
    )
    parser.add_argument("--device", default="cuda")
    # Exactly one way of choosing the target length must be supplied.
    target = parser.add_mutually_exclusive_group(required=True)
    target.add_argument(
        "--target-len", type=int, default=None,
        help="fixed target length for every example",
    )
    target.add_argument(
        "--target-ratio", type=float, default=None,
        help="target length as a fraction of the input token count",
    )
    parser.add_argument("--steps", default=1000, type=int)
    # main() calls load_config itself, so pass the raw namespace; wrapping it
    # in load_config here loaded the config twice.
    main(parser.parse_args())