Calvin committed
Commit f94a42e
1 Parent(s): 8a1aceb

final touches

Files changed (3)
  1. Exploration.ipynb +0 -0
  2. offer_pipeline.py +119 -24
  3. requirements.txt +0 -1
Exploration.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
offer_pipeline.py CHANGED
@@ -3,7 +3,6 @@ from transformers import pipeline
 import pickle
 import os
 import pandas as pd
-# import seaborn as sns
 import ast
 import string
 import re
@@ -14,79 +13,165 @@ st.set_page_config(
     layout="wide"
 )
 
-pipe = pipeline(task="zero-shot-classification", model="facebook/bart-large-mnli")
+# Download and cache models
+pipe = pipeline(task="zero-shot-classification", model="valhalla/distilbart-mnli-12-3")
 model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
 
+# Directory of csv files
 dire = "DS_NLP_search_data"
 
+# Use Streamlit caching to load data once
 @st.cache_data
 def get_processed_offers():
+    '''
+    Load processed offers from the exploration notebook and cache them
+
+    Returns:
+        processed_offers (pd.DataFrame) : zero-shot categorized offers
+    '''
     processed_offers = pd.read_csv(os.path.join(dire, "processed_offers.csv"))
     processed_offers["CATEGORY"] = processed_offers["CATEGORY"].map(ast.literal_eval)
+
     return processed_offers
 
+
 @st.cache_data
 def get_categories_data():
+    '''
+    Load raw category data and cache it
+
+    Returns:
+        cats (pd.DataFrame) : raw category data
+    '''
+
     cats = pd.read_csv(os.path.join(dire, "categories.csv"))
+
     return cats
 
+
 @st.cache_data
 def get_offers_data():
+    '''
+    Load raw offers data and cache it
+
+    Returns:
+        offers (pd.DataFrame) : raw offers data
+    '''
+
     offers = pd.read_csv(os.path.join(dire, "offer_retailer.csv"))
+
     return offers
 
+
 @st.cache_data
 def get_categories(cats_):
+    '''
+    Extract the parent categories and cache them
+
+    Parameters:
+        cats_ (pd.DataFrame) : raw categories data
+
+    Returns:
+        categories (List) : parent categories
+    '''
+
     categories = list(cats_["IS_CHILD_CATEGORY_TO"].unique())
     for x in ["Mature"]:
         if x in categories:
             categories.remove(x)
+
     return categories
 
+
 def check_in_offer(search_str, offer_rets):
+    '''
+    Find offers that directly contain the input text using basic string matching
+
+    Parameters:
+        search_str (string) : user text input
+        offer_rets (pd.DataFrame) : raw offer data
+
+    Returns:
+        df (pd.DataFrame) : offers containing the text input
+    '''
+
     offers = []
-    # print(offer_rets)
     for i in range(len(offer_rets)):
         offer_str = offer_rets.iloc[i]["OFFER"]
-        # print(offer_str)
         parsed_str = offer_str.lower().translate(str.maketrans('', '', string.punctuation))
         parsed_str = re.sub('[^a-zA-Z0-9 \n\.]', '', parsed_str)
-        # print(parsed_str)
+
         if search_str.lower() in parsed_str.split(" "):
             offers.append(offer_str)
     df = pd.DataFrame({"OFFER":offers})
-    # print(df)
+
     return df
 
+
 def is_retailer(search_str, threshold=0.5):
+    '''
+    Determine if the text input is highly likely to be a retailer
+
+    Parameters:
+        search_str (string) : user text input
+        threshold (float) : probability threshold
+
+    Returns:
+        is_ret (boolean) : True if retailer, False otherwise
+    '''
+
     processed_search_str = search_str.lower().capitalize()
     labels = pipe(processed_search_str,
                   candidate_labels=["brand", "retailer", "item"],
                   )
 
-    return labels["labels"][0] == "retailer" and labels["scores"][0] > threshold
+    is_ret = labels["labels"][0] == "retailer" and labels["scores"][0] > threshold
+
+    return is_ret
+
 
 def perform_cat_inference(search_str, categories, cats, processed_offers):
+    '''
+    Perform zero-shot classification twice and return the offers relevant to the top child categories
+
+    Parameters:
+        search_str (string) : user text input
+        categories (List) : parent categories
+        cats (pd.DataFrame) : raw category data
+        processed_offers (pd.DataFrame) : processed offer data
+
+    Returns:
+        offers (pd.DataFrame) : relevant offers
+        labels (dict) : parent categories and their probability scores
+        labels_2 (dict) : child categories and their probability scores
+    '''
+
     labels = pipe(search_str,
                   candidate_labels=categories,
                   )
-    print(labels)
     # labels = [l for i, l in enumerate(labels["labels"]) if labels["scores"][i] > 0.20]
     filtered_cats = list(cats[cats["IS_CHILD_CATEGORY_TO"].isin(labels["labels"][:3])]["PRODUCT_CATEGORY"].unique())
     labels_2 = pipe(search_str,
                     candidate_labels=filtered_cats,
                     )
-    print(labels_2)
     top_labels = labels_2["labels"][:3]
-
-
-
-    print(top_labels)
     offers = processed_offers[processed_offers["CATEGORY"].apply(lambda x: bool(set(x) & set(top_labels)))]["OFFER"].reset_index()
 
     return offers, labels, labels_2
 
+
 def sort_by_similarity(search_str, related_offers):
+    '''
+    Use sentence embeddings to evaluate the similarity of relevant offers to the text input
+
+    Parameters:
+        search_str (string) : user text input
+        related_offers (pd.DataFrame) : relevant offers discovered by zero-shot classification
+
+    Returns:
+        df (pd.DataFrame) : relevant offers and their similarity scores
+    '''
+
     temp_dict = {}
     embedding_1 = model.encode(search_str, convert_to_tensor=True)
 
@@ -96,42 +181,52 @@ def sort_by_similarity(search_str, related_offers):
         temp_dict[offer] = float(util.pytorch_cos_sim(embedding_1, embedding_2))
 
     sorted_dict = dict(sorted(temp_dict.items(), key=lambda x : x[1], reverse=True))
-    # casted_scores = list(map(lambda x : int(x), ))
     df = pd.DataFrame({"OFFER":list(sorted_dict.keys())[:20], "scores":list(sorted_dict.values())[:20]})
+
     return df
 
+
 def main():
+    # Load and cache data
     col_1, col_2, col_3 = st.columns(3)
-    search_str = col_2.text_input("Enter a retailer, brand, or category").capitalize()
+    search_str = col_1.text_input("Enter a retailer, brand, or category").capitalize()
     processed_offers = get_processed_offers()
    cats = get_categories_data()
     offer_rets = get_offers_data()
     categories = get_categories(cats)
-    # retail_mapping = get_prod_categories()
 
-    if col_2.button("Search", type="primary"):
+    if col_1.button("Search", type="primary"):
+        # Check for offers that contain the search text directly
        retail = is_retailer(search_str)
         direct_offers = check_in_offer(search_str, offer_rets)
         col_2.write("Directly related offers")
-        col_2.table(direct_offers)
+
+        if len(direct_offers) == 0:
+            col_2.write("None found")
+        else:
+            col_2.table(direct_offers)
 
         if retail:
+            # If the input is a retailer, compare every offer directly using sentence embeddings
             related_offers = offer_rets[~offer_rets["OFFER"].isin(list(direct_offers["OFFER"]))]
         else:
+            # Otherwise, use zero-shot classification over the processed offers to narrow the search
             related_offers, labels_1, labels_2 = perform_cat_inference(search_str, categories, cats, processed_offers)
             related_offers = related_offers[~related_offers["OFFER"].isin(list(direct_offers["OFFER"]))]
 
+            col_2.write("Parent categories probabilities")
             col_2.table(pd.DataFrame({"labels": labels_1["labels"][:5], "scores": labels_1["scores"][:5]}))
+            col_2.write("Child categories probabilities")
             col_2.table(pd.DataFrame({"labels": labels_2["labels"][:5], "scores": labels_2["scores"][:5]}))
-
-
-        # df = get_confidence_charts(labels_2)
-        # st.table(df)
 
         col_2.write("Other related offers")
         sorted_offers = sort_by_similarity(search_str, related_offers)
-        col_2.table(sorted_offers)
-if __name__ == "__main__":
-
+
+        if len(sorted_offers) == 0:
+            col_2.write("None found")
+        else:
+            col_2.table(sorted_offers)
+
+if __name__ == "__main__":
     main()
 
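Note on the model swap above: the commit replaces facebook/bart-large-mnli with valhalla/distilbart-mnli-12-3, a distilled MNLI checkpoint, and the surrounding code relies on the zero-shot pipeline returning candidate labels sorted by descending score (hence the labels["labels"][:3] slices). A minimal standalone sketch of that call, assuming only the transformers package from requirements.txt; the example input and the printed scores are illustrative, not taken from the app:

from transformers import pipeline

# Same task and model name as introduced in this commit
pipe = pipeline(task="zero-shot-classification", model="valhalla/distilbart-mnli-12-3")

result = pipe("Target", candidate_labels=["brand", "retailer", "item"])
# result is a dict of the form:
#   {"sequence": "Target",
#    "labels": ["retailer", "brand", "item"],   # sorted by descending score
#    "scores": [0.7, 0.2, 0.1]}                 # illustrative values only
print(result["labels"][0], result["scores"][0])
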
requirements.txt CHANGED
@@ -1,6 +1,5 @@
 streamlit
 transformers
 pandas
-seaborn
 torch
 sentence-transformers
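
The dependencies that remain in requirements.txt (transformers, torch, sentence-transformers) cover both models loaded at the top of offer_pipeline.py. For reference, the per-offer score computed in sort_by_similarity reduces to a single cosine similarity between sentence embeddings; a minimal sketch under the assumption that all-MiniLM-L6-v2 is the encoder, with a made-up query and offer string:

from sentence_transformers import SentenceTransformer, util

# Same encoder as offer_pipeline.py
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

query = "coffee"                               # hypothetical user input
offer = "Save $2 on any bag of ground coffee"  # hypothetical offer text
embedding_1 = model.encode(query, convert_to_tensor=True)
embedding_2 = model.encode(offer, convert_to_tensor=True)

# util.pytorch_cos_sim returns a 1x1 tensor; cast to float as the app does
score = float(util.pytorch_cos_sim(embedding_1, embedding_2))
print(score)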