import json
import os

import pandas as pd


def main(
    input_json: str,
    input_txt: str,
    output_dir: str,
    meta_csv: str = "output/migration/split_data/split_dev10.texts.meta.csv",
) -> None:
    """Split concatenated prediction files into one pair of files per text.

    The ``text_id`` column of ``meta_csv`` supplies the text IDs. The i-th ID
    is paired with the i-th entry of the JSON predictions and the i-th
    blank-line-separated chunk of the text predictions. For every ID a
    directory ``{output_dir}/{text_id}`` is created containing
    ``lome_{text_id}.comm.json`` (the JSON entry wrapped in a one-element
    list) and ``lome_{text_id}.comm.txt`` (the text chunk followed by a
    blank line).

    Pairing uses ``zip``, so processing stops at the shortest of the three
    sequences; extra entries (for example the empty chunk produced by a
    trailing blank line in the text file) are ignored.

    Args:
        input_json: Path to the JSON file with the concatenated predictions
            (assumed to hold a list with one entry per text).
        input_txt: Path to the text file whose per-text predictions are
            separated by a blank line.
        output_dir: Directory under which the per-text directories are
            created.
        meta_csv: Path to a CSV file with a ``text_id`` column, in the same
            order as the predictions. Defaults to the path that was
            previously hard-coded.
    """
    text_ids = pd.read_csv(meta_csv)["text_id"].to_list()

    with open(input_json, encoding="utf-8") as f:
        json_predictions = json.load(f)

    with open(input_txt, encoding="utf-8") as f:
        txt_predictions = f.read().split("\n\n")

    for t_id, json_p, txt_p in zip(text_ids, json_predictions, txt_predictions):
        # Lightweight progress indicator: report every ID divisible by 100.
        if int(t_id) % 100 == 0:
            print(t_id)

        prediction_dir = f"{output_dir}/{t_id}"
        # exist_ok avoids the check-then-create race and makes re-runs safe.
        os.makedirs(prediction_dir, exist_ok=True)

        prediction_file_json = f"{prediction_dir}/lome_{t_id}.comm.json"
        prediction_file_txt = f"{prediction_dir}/lome_{t_id}.comm.txt"

        with open(prediction_file_json, "w", encoding="utf-8") as f_out:
            # Each output file keeps the list shape of the input, with one entry.
            json.dump([json_p], f_out)

        with open(prediction_file_txt, "w", encoding="utf-8") as f_out:
            # Restore the blank-line separator that split() removed.
            f_out.write(txt_p + "\n\n")


if __name__ == "__main__":
    # Earlier invocations with other input files, kept disabled:
    # main(
    #     input_json="output/migration/lome/lome_0shot/lome_lome_0shot_migration_all_tc.comm.json",
    #     input_txt="output/migration/lome/lome_0shot/lome_lome_0shot_migration_all_tc.comm.txt",
    #     output_dir="output/migration/lome/multilabel/lome_0shot/pavia"
    # )
    # main(
    #     input_json="output/migration/lome/lome_0shot/lome_lome_0shot_migration_all_best-truecase.comm.json",
    #     input_txt="output/migration/lome/lome_0shot/lome_lome_0shot_migration_all_best-truecase.comm.txt",
    #     output_dir="output/migration/lome/multilabel/lome_0shot/pavia"
    # )
    # main(
    #     input_json="output/migration/lome/lome_zs-tgt_ev-frm/data-in.concat.combined_zs_ev.tc_bilstm.json",
    #     input_txt="output/migration/lome/lome_zs-tgt_ev-frm/data-in.concat.combined_zs_ev.tc_bilstm.txt",
    #     output_dir="output/migration/lome/multilabel/lome_zs-tgt_ev_frm/pavia"
    # )
    main(
        input_json="/home/gossminn/WorkSyncs/Code/fn-for-social-frames/output/migration/lome/lome_migration_concat.comm.json",
        input_txt="/home/gossminn/WorkSyncs/Code/fn-for-social-frames/output/migration/lome/lome_migration_concat.comm.txt",
        output_dir="output/migration/lome/multilabel/lome_0shot/pavia",
    )