# Spaces:
# Build error
# Build error
import json
import os
import sys

import pandas as pd
import requests
# Base URL of the SocioFillmore HTTP API. Can be overridden with the
# SOCIOFILLMORE_API environment variable; a trailing slash is stripped because
# endpoint paths are appended as "/<endpoint>".
SOCIOFILLMORE_API = os.environ.get(
    "SOCIOFILLMORE_API", "http://127.0.0.1:5000"
).rstrip("/")

# Key sent as the "auth_key" query parameter. Can be overridden with the
# SOCIOFILLMORE_AUTH_KEY environment variable.
# NOTE(review): the default below is a credential committed to source; it is
# kept only for backward compatibility and should be rotated / supplied via the
# environment instead.
AUTH_KEY = os.environ.get("SOCIOFILLMORE_AUTH_KEY", "3TrJ397oh#^")
def get_sample(s, dataset, n_samples, frame, construction, role, dependency):
    """Fetch sample sentences for one frame / role / dependency combination.

    First asks the server to switch to ``dataset``, then queries the
    ``/sample_frame`` endpoint and flattens the JSON response into one row per
    frame structure whose frame equals ``frame`` and which has a role labelled
    ``role``.

    Args:
        s: Session-like object whose ``get(url, params=...)`` returns a
            response exposing ``raise_for_status()`` and ``text``
            (e.g. a ``requests.Session``).
        dataset: Dataset identifier sent to ``/switch_dataset``.
        n_samples: Value sent as the ``n`` query parameter.
        frame: Frame name to query and to filter the returned structures on.
        construction: Value sent as the ``construction`` query parameter.
        role: Role label to look for inside each matching frame structure.
        dependency: Value sent as the ``dependency`` query parameter and
            copied into every output row.

    Returns:
        A list of dicts with the keys ``dataset``, ``sentence``, ``frame``,
        ``target``, ``role_label``, ``role_span`` and ``dependency``.

    Raises:
        Whatever ``raise_for_status()`` raises for an HTTP error status
        (``requests.HTTPError`` when ``s`` is a ``requests.Session``).
    """
    # Fail loudly if the dataset switch did not succeed: otherwise the samples
    # would come from whichever dataset was active before, while the rows below
    # are still labelled with ``dataset``.
    switch_resp = s.get(
        SOCIOFILLMORE_API + "/switch_dataset", params={"dataset": dataset}
    )
    switch_resp.raise_for_status()

    sample_resp = s.get(
        SOCIOFILLMORE_API + "/sample_frame",
        params={
            "auth_key": AUTH_KEY,
            "frame": frame,
            "construction": construction,
            "role": role,
            "dependency": dependency,
            "model": "lome_0shot",
            "n": n_samples,
        },
    )
    # Surface HTTP errors directly instead of as a JSON decoding failure.
    sample_resp.raise_for_status()
    sentences = json.loads(sample_resp.text)

    rows_out = []
    for sent in sentences:
        for fns in sent["fn_structures"]:
            if fns["frame"] != frame:
                continue
            # Each role entry is indexed as (label, span); only the first
            # entry carrying the requested label is used.
            target_role = next((r for r in fns["roles"] if r[0] == role), None)
            if target_role is None:
                continue
            rows_out.append(
                {
                    "dataset": dataset,
                    "sentence": " ".join(sent["sentence"]),
                    "frame": frame,
                    "target": " ".join(fns["target"]["tokens_str"]),
                    "role_label": role,
                    "role_span": " ".join(target_role[1]["tokens_str"]),
                    "dependency": dependency,
                }
            )
    return rows_out
def get_labels(s, dataset, frame):
    """Collect the distinct label values reported by ``/frame_freq`` for a frame.

    Switches the server to ``dataset``, queries ``/frame_freq`` for ``frame``
    and returns the third ``"::"``-separated field of every entry in
    ``relevant_frame_counts["x"]`` of the JSON response.

    Args:
        s: Session-like object whose ``get(url, params=...)`` returns a
            response exposing ``raise_for_status()`` and ``text``
            (e.g. a ``requests.Session``).
        dataset: Dataset identifier sent to ``/switch_dataset``.
        frame: Frame name sent as the ``frames`` query parameter.

    Returns:
        A set of strings (presumably dependency labels, given
        ``group_by_role_expr=2`` -- confirm against the API).

    Raises:
        Whatever ``raise_for_status()`` raises for an HTTP error status
        (``requests.HTTPError`` when ``s`` is a ``requests.Session``).
        IndexError: if an entry has fewer than three ``"::"``-separated fields.
    """
    # A failed switch would otherwise silently yield labels from another dataset.
    switch_resp = s.get(
        SOCIOFILLMORE_API + "/switch_dataset", params={"dataset": dataset}
    )
    switch_resp.raise_for_status()

    freq_resp = s.get(
        SOCIOFILLMORE_API + "/frame_freq",
        params={
            "auth_key": AUTH_KEY,
            "model": "lome_0shot",
            "frames": frame,
            "constructions": "",
            "group_by_cat": "n",
            "group_by_constr": "n",
            "group_by_role_expr": 2,
            "relative": "y",
            "plot_over_days_post": "n",
        },
    )
    # Surface HTTP errors directly instead of as a JSON decoding failure.
    freq_resp.raise_for_status()
    data = json.loads(freq_resp.text)

    return {entry.split("::")[2] for entry in data["relevant_frame_counts"]["x"]}
# Per-language query settings: the dataset to switch to, the frame to query and
# the roles sampled for every dependency label.
_LANGUAGE_CONFIGS = {
    "it": {
        "dataset": "femicides/rai",
        "frame": "Killing",
        "roles": ("Killer", "Victim"),
    },
    "nl": {
        "dataset": "crashes/thecrashes",
        "frame": "Cause_harm",
        "roles": ("Agent", "Victim"),
    },
}

# Directory the per-language CSV files are written to.
_OUTPUT_DIR = "output/common/query_frame_samples"


def main(language):
    """Sample sentences per dependency label for ``language`` and save a CSV.

    For the configured dataset/frame of ``language`` this fetches all labels
    via ``get_labels``, skips the ``"_UNK_DEP"`` label, requests two samples
    per (role, label) pair via ``get_sample`` and writes the collected rows to
    ``output/common/query_frame_samples/<language>_dep_samples.csv``.

    Args:
        language: ``"it"`` or ``"nl"``. Any other value writes nothing; a
            notice is printed to stderr so the no-op is not silent.
    """
    config = _LANGUAGE_CONFIGS.get(language)
    if config is None:
        supported = ", ".join(sorted(_LANGUAGE_CONFIGS))
        print(
            f"Unsupported language {language!r}; expected one of: {supported}",
            file=sys.stderr,
        )
        return

    tag = language.upper()
    dataset = config["dataset"]
    frame = config["frame"]

    sample_rows = []
    # The context manager closes the session's connections when done.
    with requests.Session() as s:
        print(f"Finding {tag} labels...")
        labels = get_labels(s, dataset, frame)
        for label in sorted(labels):
            if label == "_UNK_DEP":
                continue
            print(f"Label ({tag}): {label}")
            for role in config["roles"]:
                sample_rows.extend(
                    get_sample(s, dataset, 2, frame, "*", role, label)
                )

    # to_csv does not create missing parent directories, so make sure the
    # output directory exists before writing.
    os.makedirs(_OUTPUT_DIR, exist_ok=True)
    pd.DataFrame(sample_rows).to_csv(f"{_OUTPUT_DIR}/{language}_dep_samples.csv")
if __name__ == "__main__":
    # Script entry point: the first command-line argument selects the language.
    # Exit with a usage message instead of an IndexError traceback when it is
    # missing; any extra arguments are ignored, as before.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <language: it|nl>")
    main(language=sys.argv[1])