Eran Fainman committed on
Commit
071fa8f
·
1 Parent(s): 83fd180

Add application file

Browse files
Files changed (1) hide show
  1. app.py +19 -5
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import uuid
 
3
  import pandas as pd
4
  import streamlit as st
5
  import huggingface_hub as hh
@@ -24,7 +25,6 @@ GROUPS_INFO_PATH = os.path.join(CACHE_PATH, "groups-info")
24
 
25
  REQUIRED_COLUMNS = ['accommodation_id', 'user_id'] + [f'review_{i}' for i in range(1, 11)]
26
 
27
-
28
  API = hh.HfApi(token=TOKEN)
29
 
30
 
@@ -54,12 +54,26 @@ def refresh_data():
54
  refresh_data()
55
 
56
 
 
 
 
 
 
 
 
57
  def calculate_metrics(df_pred):
58
  df_gt = pd.read_csv(os.path.join(GT_PATH, 'val_matches.csv'))
 
 
 
 
 
 
 
 
 
59
 
60
- mrr10 = 0.3
61
- precision10 = 0.2
62
- return mrr10, precision10
63
 
64
 
65
  def get_group_name_by_email(email):
@@ -148,7 +162,7 @@ def render():
148
  submission_tab.markdown("**Group name:** " + group_name)
149
  submission_tab.markdown("**Model name:** " + model_name)
150
  submission_tab.markdown("**Revision:** " + str(revision) +
151
- f"(out of {MAX_SUBMISSIONS} allowed submissions)")
152
 
153
  submission_tab.write("### Submission results")
154
  submission_tab.markdown("**MRR@10:** {:.4f}".format(mrr10))
 
1
  import os
2
  import uuid
3
+ import numpy as np
4
  import pandas as pd
5
  import streamlit as st
6
  import huggingface_hub as hh
 
25
 
26
  REQUIRED_COLUMNS = ['accommodation_id', 'user_id'] + [f'review_{i}' for i in range(1, 11)]
27
 
 
28
  API = hh.HfApi(token=TOKEN)
29
 
30
 
 
54
  refresh_data()
55
 
56
 
57
def get_match_index(row, k=10):
    """Return the 1-based rank at which the row's ``review_id`` was predicted.

    Args:
        row: Mapping-like object (for example one row of the merged
            ground-truth/prediction frame) holding ``review_id`` and the
            ranked candidate columns ``review_1`` .. ``review_k``.
        k: Number of ranked candidate columns to inspect. Defaults to 10,
            the cutoff used by ``calculate_metrics``.

    Returns:
        The smallest ``i`` in ``1..k`` for which ``row['review_i']`` equals
        ``row['review_id']``, or ``np.inf`` when no candidate matches.
    """
    target = row['review_id']
    return next(
        (rank for rank in range(1, k + 1) if row[f'review_{rank}'] == target),
        np.inf,
    )


def calculate_metrics(df_pred):
    """Score a predictions frame against the validation ground truth.

    The ground truth is read from ``val_matches.csv`` under ``GT_PATH`` and
    left-joined with ``df_pred`` on ``accommodation_id`` and ``user_id``.
    Missing values after the join are replaced with ``''`` so that a
    ground-truth row without a prediction simply never matches.
    NOTE(review): ``review_id`` is presumably supplied by the ground-truth
    file - confirm against ``val_matches.csv``.

    Args:
        df_pred: DataFrame with ``accommodation_id``, ``user_id`` and the
            ranked candidate columns ``review_1`` .. ``review_10``.

    Returns:
        A ``(mrr10, precision10)`` tuple of floats: the mean reciprocal rank
        of the matching review (0 when it is not among the candidates), and
        the fraction of rows whose matching review appears among the
        candidates.

    Raises:
        ValueError: If ``df_pred`` does not have the same number of rows as
            the ground-truth file.
    """
    df_gt = pd.read_csv(os.path.join(GT_PATH, 'val_matches.csv'))
    if len(df_pred) != len(df_gt):
        # The count can be off in either direction, so the message must not
        # claim that rows are only ever missing.
        raise ValueError(
            "Your predictions file should contain {} rows, but {} rows were found in the file".format(
                len(df_gt), len(df_pred)
            )
        )

    df_merged = pd.merge(
        df_gt, df_pred, how='left', on=['accommodation_id', 'user_id']
    ).fillna('')

    # Rank of the correct review per row; np.inf marks "not in the top 10".
    match_index = df_merged.apply(get_match_index, axis=1).astype(float)

    # 1 / inf == 0.0, so rows without a match contribute nothing to the MRR.
    reciprocal_rank = 1.0 / match_index
    # The hit indicator is derived from the rank column itself; applying the
    # test to the whole DataFrame would evaluate it per column and fail.
    hit = np.isfinite(match_index)

    return float(reciprocal_rank.mean()), float(hit.mean())
 
 
77
 
78
 
79
  def get_group_name_by_email(email):
 
162
  submission_tab.markdown("**Group name:** " + group_name)
163
  submission_tab.markdown("**Model name:** " + model_name)
164
  submission_tab.markdown("**Revision:** " + str(revision) +
165
+ f" (out of {MAX_SUBMISSIONS} allowed submissions)")
166
 
167
  submission_tab.write("### Submission results")
168
  submission_tab.markdown("**MRR@10:** {:.4f}".format(mrr10))