Spaces:
Runtime error
Runtime error
File size: 2,251 Bytes
8655a4b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
"""
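Clean model judgment files.

Reads a judgment ``.jsonl`` file given by ``--infile``, drops duplicate
judgments, keeps only records whose model (``model_1`` for pairwise records)
is listed in ``selected_models``, sorts them by model, question id and turn,
and writes the result next to the input as ``*_clean.jsonl``.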
"""
import argparse
import json
# Models whose judgments survive cleaning. Records for any other model are
# dropped, and listed models that never appear in the input are reported as
# missing by the script.
selected_models = [
    "alpaca-13b",
    "baize-v2-13b",
    "chatglm-6b",
    "claude-instant-v1",
    "claude-v1",
    "dolly-v2-12b",
    "falcon-40b-instruct",
    "fastchat-t5-3b",
    "gpt-3.5-turbo",
    "gpt-4",
    "gpt4all-13b-snoozy",
    "guanaco-33b",
    "guanaco-65b",
    "h2ogpt-oasst-open-llama-13b",
    "koala-13b",
    "llama-13b",
    "mpt-30b-chat",
    "mpt-30b-instruct",
    "mpt-7b-chat",
    "nous-hermes-13b",
    "oasst-sft-4-pythia-12b",
    "oasst-sft-7-llama-30b",
    "palm-2-chat-bison-001",
    "rwkv-4-raven-14b",
    "stablelm-tuned-alpha-7b",
    "tulu-30b",
    "vicuna-13b-v1.3",
    "vicuna-33b-v1.3",
    "vicuna-7b-v1.3",
    "wizardlm-13b",
    "wizardlm-30b",
]
def clean_outfile_path(infile):
    """Return the path the cleaned copy of *infile* is written to.

    ``.jsonl`` in the path is replaced by ``_clean.jsonl``. When the path does
    not contain ``.jsonl`` the replacement would be a no-op and the output
    would overwrite the input, so ``_clean.jsonl`` is appended instead.
    """
    outfile = infile.replace(".jsonl", "_clean.jsonl")
    if outfile == infile:
        # Never write the cleaned data on top of the file being read.
        outfile = infile + "_clean.jsonl"
    return outfile


def judgment_key(obj):
    """Return the hashable key used to detect duplicate judgment records.

    Pairwise records (those with a ``model_1`` field) are identified by both
    models, the question id and the judge; single-model records by the model,
    the question id and the judge.
    """
    if "model_1" in obj:  # pair
        return (
            obj["model_1"],
            obj["model_2"],
            obj["question_id"],
            tuple(obj["judge"]),
        )
    # single
    return (obj["model"], obj["question_id"], tuple(obj["judge"]))


def clean_judgments(raw_lines, allowed_models):
    """Deduplicate, filter and sort judgment records.

    Args:
        raw_lines: Iterable of JSON-encoded judgment records, one per item.
            Blank lines are skipped.
        allowed_models: Model names to keep. For pairwise records the check
            is made against ``model_1``.

    Returns:
        A ``(records, models)`` tuple: the kept records sorted by model
        (``model`` or ``model_1``), ``question_id`` and ``turn``, and the
        sorted list of model names that appear in them.
    """
    allowed = set(allowed_models)  # O(1) membership instead of a list scan
    rets = []
    models = set()
    visited = set()
    for line in raw_lines:
        if not line.strip():
            # Tolerate blank lines (e.g. a trailing newline) in the input.
            continue
        obj = json.loads(line)

        key = judgment_key(obj)
        if key in visited:
            # Only the first occurrence of a judgment is kept.
            continue
        visited.add(key)

        model = obj["model_1"] if "model_1" in obj else obj["model"]
        if model not in allowed:
            continue
        models.add(model)
        rets.append(obj)

    rets.sort(
        key=lambda x: (
            x["model"] if "model" in x else x["model_1"],
            x["question_id"],
            x["turn"],
        )
    )
    return rets, sorted(models)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Without --infile there is nothing to clean; fail with a usage message
    # instead of an AttributeError further down.
    parser.add_argument("--infile", type=str, required=True)
    args = parser.parse_args()

    infile = args.infile
    outfile = clean_outfile_path(infile)

    # Context manager so the input handle is closed deterministically.
    with open(infile) as fin:
        raw_lines = fin.readlines()

    rets, models = clean_judgments(raw_lines, selected_models)
    missing_models = [x for x in selected_models if x not in models]

    print(f"in models: {models}, number: {len(models)}")
    print(f"missing models: {missing_models}")
    print(f"#in: {len(raw_lines)}, #out: {len(rets)}")

    with open(outfile, "w") as fout:
        for x in rets:
            fout.write(json.dumps(x) + "\n")
|