Eran Fainman
commited on
Commit
·
071fa8f
1
Parent(s):
83fd180
Add application file
Browse files
app.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import os
|
2 |
import uuid
|
|
|
3 |
import pandas as pd
|
4 |
import streamlit as st
|
5 |
import huggingface_hub as hh
|
@@ -24,7 +25,6 @@ GROUPS_INFO_PATH = os.path.join(CACHE_PATH, "groups-info")
|
|
24 |
|
25 |
REQUIRED_COLUMNS = ['accommodation_id', 'user_id'] + [f'review_{i}' for i in range(1, 11)]
|
26 |
|
27 |
-
|
28 |
API = hh.HfApi(token=TOKEN)
|
29 |
|
30 |
|
@@ -54,12 +54,26 @@ def refresh_data():
|
|
54 |
refresh_data()
|
55 |
|
56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
def calculate_metrics(df_pred):
|
58 |
df_gt = pd.read_csv(os.path.join(GT_PATH, 'val_matches.csv'))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
-
mrr10
|
61 |
-
precision10 = 0.2
|
62 |
-
return mrr10, precision10
|
63 |
|
64 |
|
65 |
def get_group_name_by_email(email):
|
@@ -148,7 +162,7 @@ def render():
|
|
148 |
submission_tab.markdown("**Group name:** " + group_name)
|
149 |
submission_tab.markdown("**Model name:** " + model_name)
|
150 |
submission_tab.markdown("**Revision:** " + str(revision) +
|
151 |
-
f"(out of {MAX_SUBMISSIONS} allowed submissions)")
|
152 |
|
153 |
submission_tab.write("### Submission results")
|
154 |
submission_tab.markdown("**MRR@10:** {:.4f}".format(mrr10))
|
|
|
1 |
import os
|
2 |
import uuid
|
3 |
+
import numpy as np
|
4 |
import pandas as pd
|
5 |
import streamlit as st
|
6 |
import huggingface_hub as hh
|
|
|
25 |
|
26 |
REQUIRED_COLUMNS = ['accommodation_id', 'user_id'] + [f'review_{i}' for i in range(1, 11)]
|
27 |
|
|
|
28 |
API = hh.HfApi(token=TOKEN)
|
29 |
|
30 |
|
|
|
54 |
refresh_data()
|
55 |
|
56 |
|
57 |
+
def get_match_index(row):
|
58 |
+
for i in range(1, 11):
|
59 |
+
if row['review_id'] == row[f'review_{i}']:
|
60 |
+
return i
|
61 |
+
return np.inf
|
62 |
+
|
63 |
+
|
64 |
def calculate_metrics(df_pred):
|
65 |
df_gt = pd.read_csv(os.path.join(GT_PATH, 'val_matches.csv'))
|
66 |
+
if len(df_pred) != len(df_gt):
|
67 |
+
raise Exception("Your predictions file should contain {} rows, only {} rows were found in the file".format(
|
68 |
+
len(df_gt), len(df_pred)
|
69 |
+
))
|
70 |
+
|
71 |
+
df_merged = pd.merge(df_gt, df_pred, how='left', on=['accommodation_id', 'user_id']).fillna('')
|
72 |
+
df_merged['match_index'] = df_merged.apply(get_match_index, axis=1)
|
73 |
+
df_merged['mrr10'] = df_merged['match_index'].apply(lambda x: 1/x)
|
74 |
+
df_merged['precision10'] = df_merged.apply(lambda x: 1 if x != np.inf else 0)
|
75 |
|
76 |
+
return df_merged['mrr10'].mean(), df_merged['precision10'].mean()
|
|
|
|
|
77 |
|
78 |
|
79 |
def get_group_name_by_email(email):
|
|
|
162 |
submission_tab.markdown("**Group name:** " + group_name)
|
163 |
submission_tab.markdown("**Model name:** " + model_name)
|
164 |
submission_tab.markdown("**Revision:** " + str(revision) +
|
165 |
+
f" (out of {MAX_SUBMISSIONS} allowed submissions)")
|
166 |
|
167 |
submission_tab.write("### Submission results")
|
168 |
submission_tab.markdown("**MRR@10:** {:.4f}".format(mrr10))
|