Pietro Lesci committed on
Commit
51cab9d
1 Parent(s): 02c2d7e
app.py CHANGED
@@ -16,7 +16,9 @@ st.set_page_config(
16
  )
17
 
18
  # session state
19
- session = session_state.get(process=False, run_id=0, posdf=None, negdf=None, uploaded_file_id=0)
 
 
20
 
21
 
22
  # ==== SIDEBAR ==== #
@@ -42,7 +44,9 @@ st.sidebar.markdown("")
42
  st.sidebar.markdown("")
43
  st.sidebar.header("Upload file")
44
  # with st.sidebar.beta_container():
45
- uploaded_file = st.sidebar.file_uploader("Select file", type=[i.name for i in SupportedFiles])
 
 
46
 
47
 
48
  # FOOTER
@@ -62,4 +66,4 @@ with st.beta_container():
62
  st.title("Wordify")
63
 
64
 
65
- page.write(session, uploaded_file)
 
16
  )
17
 
18
  # session state
19
+ session = session_state.get(
20
+ process=False, run_id=0, posdf=None, negdf=None, uploaded_file_id=0
21
+ )
22
 
23
 
24
  # ==== SIDEBAR ==== #
 
44
  st.sidebar.markdown("")
45
  st.sidebar.header("Upload file")
46
  # with st.sidebar.beta_container():
47
+ uploaded_file = st.sidebar.file_uploader(
48
+ "Select file", type=[i.name for i in SupportedFiles]
49
+ )
50
 
51
 
52
  # FOOTER
 
66
  st.title("Wordify")
67
 
68
 
69
+ page.write(session, uploaded_file)
notebooks/wordifier_nb.ipynb CHANGED
@@ -61,11 +61,29 @@
61
  },
62
  {
63
  "cell_type": "code",
64
- "execution_count": 2,
 
 
 
 
 
 
 
 
 
65
  "metadata": {},
66
  "outputs": [],
67
  "source": [
68
- "df = pd.read_excel(\"../data/test_de.xlsx\")\n",
 
 
 
 
 
 
 
 
 
69
  "# mdf = mpd.read_csv(\"../data/test_en.csv\")\n",
70
  "language = \"English\"\n",
71
  "nlp = spacy.load(Languages[language].value, exclude=[\"parser\", \"ner\", \"pos\", \"tok2vec\"])"
@@ -73,7 +91,7 @@
73
  },
74
  {
75
  "cell_type": "code",
76
- "execution_count": 3,
77
  "metadata": {},
78
  "outputs": [],
79
  "source": [
@@ -86,19 +104,14 @@
86
  },
87
  {
88
  "cell_type": "code",
89
- "execution_count": 4,
90
  "metadata": {},
91
  "outputs": [
92
  {
93
  "output_type": "stream",
94
  "name": "stderr",
95
  "text": [
96
- "2021-05-10 18:34:49.425 WARNING root: \n",
97
- " \u001b[33m\u001b[1mWarning:\u001b[0m to view this Streamlit app on a browser, run it with the following\n",
98
- " command:\n",
99
- "\n",
100
- " streamlit run /Users/49796/miniconda3/envs/py38/lib/python3.8/site-packages/ipykernel_launcher.py [ARGUMENTS]\n",
101
- "100%|██████████| 6269/6269 [00:02<00:00, 2750.45it/s]\n"
102
  ]
103
  }
104
  ],
@@ -108,7 +121,7 @@
108
  },
109
  {
110
  "cell_type": "code",
111
- "execution_count": 5,
112
  "metadata": {},
113
  "outputs": [],
114
  "source": [
 
61
  },
62
  {
63
  "cell_type": "code",
64
+ "execution_count": 4,
65
+ "metadata": {},
66
+ "outputs": [],
67
+ "source": [
68
+ "path = \"../../../../Downloads/wordify_10000_copy.xlsx\""
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "code",
73
+ "execution_count": 28,
74
  "metadata": {},
75
  "outputs": [],
76
  "source": [
77
+ "df = pd.read_excel(path, dtype=str).dropna()"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": 29,
83
+ "metadata": {},
84
+ "outputs": [],
85
+ "source": [
86
+ "# df = pd.read_excel(\"../data/test_de.xlsx\")\n",
87
  "# mdf = mpd.read_csv(\"../data/test_en.csv\")\n",
88
  "language = \"English\"\n",
89
  "nlp = spacy.load(Languages[language].value, exclude=[\"parser\", \"ner\", \"pos\", \"tok2vec\"])"
 
91
  },
92
  {
93
  "cell_type": "code",
94
+ "execution_count": 30,
95
  "metadata": {},
96
  "outputs": [],
97
  "source": [
 
104
  },
105
  {
106
  "cell_type": "code",
107
+ "execution_count": 31,
108
  "metadata": {},
109
  "outputs": [
110
  {
111
  "output_type": "stream",
112
  "name": "stderr",
113
  "text": [
114
+ "100%|██████████| 9939/9939 [00:06<00:00, 1431.09it/s]\n"
 
 
 
 
 
115
  ]
116
  }
117
  ],
 
121
  },
122
  {
123
  "cell_type": "code",
124
+ "execution_count": 32,
125
  "metadata": {},
126
  "outputs": [],
127
  "source": [
src/pages/about.py CHANGED
@@ -31,4 +31,4 @@ def write(*args):
31
  <iframe src="https://www.google.com/maps/embed?pb=!1m18!1m12!1m3!1d2798.949796165441!2d9.185730115812493!3d45.450667779100726!2m3!1f0!2f0!3f0!3m2!1i1024!2i768!4f13.1!3m3!1m2!1s0x4786c405ae6543c9%3A0xf2bb2313b36af88c!2sVia%20Guglielmo%20R%C3%B6ntgen%2C%201%2C%2020136%20Milano%20MI!5e0!3m2!1sit!2sit!4v1569325279433!5m2!1sit!2sit" frameborder="0" style="border:0; width: 100%; height: 312px;" allowfullscreen></iframe>
32
  """,
33
  unsafe_allow_html=True,
34
- )
 
31
  <iframe src="https://www.google.com/maps/embed?pb=!1m18!1m12!1m3!1d2798.949796165441!2d9.185730115812493!3d45.450667779100726!2m3!1f0!2f0!3f0!3m2!1i1024!2i768!4f13.1!3m3!1m2!1s0x4786c405ae6543c9%3A0xf2bb2313b36af88c!2sVia%20Guglielmo%20R%C3%B6ntgen%2C%201%2C%2020136%20Milano%20MI!5e0!3m2!1sit!2sit!4v1569325279433!5m2!1sit!2sit" frameborder="0" style="border:0; width: 100%; height: 312px;" allowfullscreen></iframe>
32
  """,
33
  unsafe_allow_html=True,
34
+ )
src/pages/home.py CHANGED
@@ -1,7 +1,6 @@
1
  from src.configs import Languages
2
  from src.utils import (
3
  encode,
4
- wordifier,
5
  download_button,
6
  TextPreprocessor,
7
  plot_labels_prop,
@@ -9,28 +8,33 @@ from src.utils import (
9
  plot_score,
10
  read_file,
11
  )
 
12
  import streamlit as st
13
 
14
 
15
  def write(session, uploaded_file):
16
 
17
- st.markdown(
18
- """
19
- Hi! Welcome to __Wordify__. Start by uploading a file - CSV, XLSX (avoid Strict Open XML Spreadsheet format [here](https://stackoverflow.com/questions/62800822/openpyxl-cannot-read-strict-open-xml-spreadsheet-format-userwarning-file-conta)),
20
- or PARQUET are currently supported.
21
 
22
- Once you have uploaded the file, __Wordify__ will show an interactive UI through which
23
- you'll be able to interactively decide the text preprocessing steps, their order, and
24
- proceed to Wordify your text.
25
 
26
- If you're ready, let's jump in:
 
 
27
 
28
- :point_left: upload a file via the upload widget in the sidebar!
29
 
30
- NOTE: whenever you want to reset everything, simply refresh the page
31
- """
32
- )
33
- if uploaded_file:
 
 
 
34
 
35
  # 1. READ FILE
36
  with st.spinner("Reading file"):
@@ -38,10 +42,6 @@ def write(session, uploaded_file):
38
  data = read_file(uploaded_file)
39
 
40
  # 2. CREATE UI TO SELECT COLUMNS
41
- st.markdown("")
42
- st.markdown("")
43
- st.header("Process")
44
-
45
  col1, col2, col3 = st.beta_columns(3)
46
  with col1:
47
  language = st.selectbox("Select language", [i.name for i in Languages])
@@ -51,13 +51,16 @@ def write(session, uploaded_file):
51
  )
52
  with col2:
53
  cols_options = [""] + data.columns.tolist()
54
- label_column = st.selectbox("Select label column name", cols_options, index=0)
 
 
55
  with st.beta_expander("Description"):
56
  st.markdown("Select the column containing the label")
57
 
58
  if label_column:
59
  plot = plot_labels_prop(data, label_column)
60
- if plot: st.altair_chart(plot, use_container_width=True)
 
61
 
62
  with col3:
63
  text_column = st.selectbox("Select text column name", cols_options, index=0)
@@ -65,7 +68,9 @@ def write(session, uploaded_file):
65
  st.markdown("Select the column containing the text")
66
 
67
  if text_column:
68
- st.altair_chart(plot_nchars(data, text_column), use_container_width=True)
 
 
69
 
70
  with st.beta_expander("Advanced options"):
71
  # Lemmatization option
@@ -102,14 +107,18 @@ def write(session, uploaded_file):
102
  format_func=lambda x: x.replace("_", " ").title(),
103
  key=session.run_id,
104
  )
105
- lemmatization_options = list(TextPreprocessor._lemmatization_options().keys())
 
 
106
  lemmatization_when = lemmatization_when_elem.selectbox(
107
  "Select when lemmatization happens",
108
  options=lemmatization_options,
109
  index=0,
110
  key=session.run_id,
111
  )
112
- remove_stopwords = remove_stopwords_elem.checkbox("Remove stopwords", value=True, key=session.run_id)
 
 
113
 
114
  # Show sample checkbox
115
  col1, col2 = st.beta_columns([1, 2])
@@ -130,8 +139,14 @@ def write(session, uploaded_file):
130
 
131
  elif show_sample and (label_column and text_column):
132
  sample_data = data.sample(10)
133
- sample_data[f"preprocessed_{text_column}"] = preprocessor.fit_transform(sample_data[text_column]).values
134
- st.table(sample_data.loc[:, [label_column, text_column, f"preprocessed_{text_column}"]])
 
 
 
 
 
 
135
 
136
  # 4. RUN
137
  run_button = st.button("Wordify!")
@@ -142,7 +157,9 @@ def write(session, uploaded_file):
142
 
143
  with st.spinner("Process started"):
144
  # data = data.head()
145
- data[f"preprocessed_{text_column}"] = preprocessor.fit_transform(data[text_column]).values
 
 
146
 
147
  inputs = encode(data[f"preprocessed_{text_column}"], data[label_column])
148
  session.posdf, session.negdf = wordifier(**inputs)
@@ -161,7 +178,9 @@ def write(session, uploaded_file):
161
  col1, col2, col3 = st.beta_columns([2, 3, 3])
162
 
163
  with col1:
164
- label = st.selectbox("Select label", data[label_column].unique().tolist())
 
 
165
  # # with col2:
166
  # thres = st.slider(
167
  # "Select threshold",
@@ -175,14 +194,28 @@ def write(session, uploaded_file):
175
 
176
  with col2:
177
  st.subheader(f"Words __positively__ identifying label `{label}`")
178
- st.write(session.posdf[session.posdf[label_column] == label].sort_values("score", ascending=False))
 
 
 
 
179
  download_button(session.posdf, "positive_data")
180
  if show_plots:
181
- st.altair_chart(plot_score(session.posdf, label_column, label), use_container_width=True)
 
 
 
182
 
183
  with col3:
184
  st.subheader(f"Words __negatively__ identifying label `{label}`")
185
- st.write(session.negdf[session.negdf[label_column] == label].sort_values("score", ascending=False))
 
 
 
 
186
  download_button(session.negdf, "negative_data")
187
  if show_plots:
188
- st.altair_chart(plot_score(session.negdf, label_column, label), use_container_width=True)
 
 
 
 
1
  from src.configs import Languages
2
  from src.utils import (
3
  encode,
 
4
  download_button,
5
  TextPreprocessor,
6
  plot_labels_prop,
 
8
  plot_score,
9
  read_file,
10
  )
11
+ from src.wordifier import wordifier
12
  import streamlit as st
13
 
14
 
15
  def write(session, uploaded_file):
16
 
17
+ if not uploaded_file:
18
+ st.markdown(
19
+ """
20
+ Hi, welcome to __Wordify__! :rocket:
21
 
22
+ Start by uploading a file - CSV, XLSX (avoid Strict Open XML Spreadsheet format [here](https://stackoverflow.com/questions/62800822/openpyxl-cannot-read-strict-open-xml-spreadsheet-format-userwarning-file-conta)),
23
+ or PARQUET are currently supported.
 
24
 
25
+ Once you have uploaded the file, __Wordify__ will show an interactive UI through which
26
+ you'll be able to interactively decide the text preprocessing steps, their order, and
27
+ proceed to Wordify your text.
28
 
29
+ If you're ready, let's jump in:
30
 
31
+ :point_left: upload a file via the upload widget in the sidebar!
32
+
33
+ NOTE: whenever you want to reset everything, simply refresh the page.
34
+ """
35
+ )
36
+
37
+ elif uploaded_file:
38
 
39
  # 1. READ FILE
40
  with st.spinner("Reading file"):
 
42
  data = read_file(uploaded_file)
43
 
44
  # 2. CREATE UI TO SELECT COLUMNS
 
 
 
 
45
  col1, col2, col3 = st.beta_columns(3)
46
  with col1:
47
  language = st.selectbox("Select language", [i.name for i in Languages])
 
51
  )
52
  with col2:
53
  cols_options = [""] + data.columns.tolist()
54
+ label_column = st.selectbox(
55
+ "Select label column name", cols_options, index=0
56
+ )
57
  with st.beta_expander("Description"):
58
  st.markdown("Select the column containing the label")
59
 
60
  if label_column:
61
  plot = plot_labels_prop(data, label_column)
62
+ if plot:
63
+ st.altair_chart(plot, use_container_width=True)
64
 
65
  with col3:
66
  text_column = st.selectbox("Select text column name", cols_options, index=0)
 
68
  st.markdown("Select the column containing the text")
69
 
70
  if text_column:
71
+ st.altair_chart(
72
+ plot_nchars(data, text_column), use_container_width=True
73
+ )
74
 
75
  with st.beta_expander("Advanced options"):
76
  # Lemmatization option
 
107
  format_func=lambda x: x.replace("_", " ").title(),
108
  key=session.run_id,
109
  )
110
+ lemmatization_options = list(
111
+ TextPreprocessor._lemmatization_options().keys()
112
+ )
113
  lemmatization_when = lemmatization_when_elem.selectbox(
114
  "Select when lemmatization happens",
115
  options=lemmatization_options,
116
  index=0,
117
  key=session.run_id,
118
  )
119
+ remove_stopwords = remove_stopwords_elem.checkbox(
120
+ "Remove stopwords", value=True, key=session.run_id
121
+ )
122
 
123
  # Show sample checkbox
124
  col1, col2 = st.beta_columns([1, 2])
 
139
 
140
  elif show_sample and (label_column and text_column):
141
  sample_data = data.sample(10)
142
+ sample_data[f"preprocessed_{text_column}"] = preprocessor.fit_transform(
143
+ sample_data[text_column]
144
+ ).values
145
+ st.table(
146
+ sample_data.loc[
147
+ :, [label_column, text_column, f"preprocessed_{text_column}"]
148
+ ]
149
+ )
150
 
151
  # 4. RUN
152
  run_button = st.button("Wordify!")
 
157
 
158
  with st.spinner("Process started"):
159
  # data = data.head()
160
+ data[f"preprocessed_{text_column}"] = preprocessor.fit_transform(
161
+ data[text_column]
162
+ ).values
163
 
164
  inputs = encode(data[f"preprocessed_{text_column}"], data[label_column])
165
  session.posdf, session.negdf = wordifier(**inputs)
 
178
  col1, col2, col3 = st.beta_columns([2, 3, 3])
179
 
180
  with col1:
181
+ label = st.selectbox(
182
+ "Select label", data[label_column].unique().tolist()
183
+ )
184
  # # with col2:
185
  # thres = st.slider(
186
  # "Select threshold",
 
194
 
195
  with col2:
196
  st.subheader(f"Words __positively__ identifying label `{label}`")
197
+ st.write(
198
+ session.posdf[session.posdf[label_column] == label].sort_values(
199
+ "score", ascending=False
200
+ )
201
+ )
202
  download_button(session.posdf, "positive_data")
203
  if show_plots:
204
+ st.altair_chart(
205
+ plot_score(session.posdf, label_column, label),
206
+ use_container_width=True,
207
+ )
208
 
209
  with col3:
210
  st.subheader(f"Words __negatively__ identifying label `{label}`")
211
+ st.write(
212
+ session.negdf[session.negdf[label_column] == label].sort_values(
213
+ "score", ascending=False
214
+ )
215
+ )
216
  download_button(session.negdf, "negative_data")
217
  if show_plots:
218
+ st.altair_chart(
219
+ plot_score(session.negdf, label_column, label),
220
+ use_container_width=True,
221
+ )
src/plotting.py CHANGED
@@ -22,7 +22,12 @@ def plot_labels_prop(data: pd.DataFrame, label_column: str):
22
 
23
  return
24
 
25
- source = data[label_column].value_counts().reset_index().rename(columns={"index": "Labels", label_column: "Counts"})
 
 
 
 
 
26
  source["Props"] = source["Counts"] / source["Counts"].sum()
27
  source["Proportions"] = (source["Props"].round(3) * 100).map("{:,.2f}".format) + "%"
28
 
@@ -35,7 +40,9 @@ def plot_labels_prop(data: pd.DataFrame, label_column: str):
35
  )
36
  )
37
 
38
- text = bars.mark_text(align="center", baseline="middle", dy=15).encode(text="Proportions:O")
 
 
39
 
40
  return (bars + text).properties(height=300)
41
 
@@ -47,7 +54,9 @@ def plot_nchars(data: pd.DataFrame, text_column: str):
47
  alt.Chart(source)
48
  .mark_bar()
49
  .encode(
50
- alt.X(f"{text_column}:Q", bin=True, axis=alt.Axis(title="# chars per text")),
 
 
51
  alt.Y("count()", axis=alt.Axis(title="")),
52
  )
53
  )
@@ -57,7 +66,11 @@ def plot_nchars(data: pd.DataFrame, text_column: str):
57
 
58
  def plot_score(data: pd.DataFrame, label_col: str, label: str):
59
 
60
- source = data.loc[data[label_col] == label].sort_values("score", ascending=False).head(100)
 
 
 
 
61
 
62
  plot = (
63
  alt.Chart(source)
 
22
 
23
  return
24
 
25
+ source = (
26
+ data[label_column]
27
+ .value_counts()
28
+ .reset_index()
29
+ .rename(columns={"index": "Labels", label_column: "Counts"})
30
+ )
31
  source["Props"] = source["Counts"] / source["Counts"].sum()
32
  source["Proportions"] = (source["Props"].round(3) * 100).map("{:,.2f}".format) + "%"
33
 
 
40
  )
41
  )
42
 
43
+ text = bars.mark_text(align="center", baseline="middle", dy=15).encode(
44
+ text="Proportions:O"
45
+ )
46
 
47
  return (bars + text).properties(height=300)
48
 
 
54
  alt.Chart(source)
55
  .mark_bar()
56
  .encode(
57
+ alt.X(
58
+ f"{text_column}:Q", bin=True, axis=alt.Axis(title="# chars per text")
59
+ ),
60
  alt.Y("count()", axis=alt.Axis(title="")),
61
  )
62
  )
 
66
 
67
  def plot_score(data: pd.DataFrame, label_col: str, label: str):
68
 
69
+ source = (
70
+ data.loc[data[label_col] == label]
71
+ .sort_values("score", ascending=False)
72
+ .head(100)
73
+ )
74
 
75
  plot = (
76
  alt.Chart(source)
src/preprocessing.py CHANGED
@@ -121,7 +121,9 @@ class TextPreprocessor:
121
 
122
  def lemmatizer(doc: spacy.tokens.doc.Doc) -> str:
123
  """Lemmatizes spacy Doc and removes stopwords"""
124
- return " ".join([t.lemma_ for t in doc if t.lemma_ != "-PRON-" and not t.is_stop])
 
 
125
 
126
  else:
127
 
 
121
 
122
  def lemmatizer(doc: spacy.tokens.doc.Doc) -> str:
123
  """Lemmatizes spacy Doc and removes stopwords"""
124
+ return " ".join(
125
+ [t.lemma_ for t in doc if t.lemma_ != "-PRON-" and not t.is_stop]
126
+ )
127
 
128
  else:
129
 
src/session_state.py CHANGED
@@ -100,13 +100,17 @@ def get(**kwargs):
100
  (not hasattr(s, "_main_dg") and s.enqueue == ctx.enqueue)
101
  or
102
  # Streamlit >= 0.65.2
103
- (not hasattr(s, "_main_dg") and s._uploaded_file_mgr == ctx.uploaded_file_mgr)
 
 
 
104
  ):
105
  this_session = s
106
 
107
  if this_session is None:
108
  raise RuntimeError(
109
- "Oh noes. Couldn't get your Streamlit Session object. " "Are you doing something fancy with threads?"
 
110
  )
111
 
112
  # Got the session object! Now let's attach some state into it.
 
100
  (not hasattr(s, "_main_dg") and s.enqueue == ctx.enqueue)
101
  or
102
  # Streamlit >= 0.65.2
103
+ (
104
+ not hasattr(s, "_main_dg")
105
+ and s._uploaded_file_mgr == ctx.uploaded_file_mgr
106
+ )
107
  ):
108
  this_session = s
109
 
110
  if this_session is None:
111
  raise RuntimeError(
112
+ "Oh noes. Couldn't get your Streamlit Session object. "
113
+ "Are you doing something fancy with threads?"
114
  )
115
 
116
  # Got the session object! Now let's attach some state into it.
src/utils.py CHANGED
@@ -55,7 +55,12 @@ def plot_labels_prop(data: pd.DataFrame, label_column: str):
55
 
56
  return
57
 
58
- source = data[label_column].value_counts().reset_index().rename(columns={"index": "Labels", label_column: "Counts"})
 
 
 
 
 
59
  source["Props"] = source["Counts"] / source["Counts"].sum()
60
  source["Proportions"] = (source["Props"].round(3) * 100).map("{:,.2f}".format) + "%"
61
 
@@ -68,7 +73,9 @@ def plot_labels_prop(data: pd.DataFrame, label_column: str):
68
  )
69
  )
70
 
71
- text = bars.mark_text(align="center", baseline="middle", dy=15).encode(text="Proportions:O")
 
 
72
 
73
  return (bars + text).properties(height=300)
74
 
@@ -80,7 +87,9 @@ def plot_nchars(data: pd.DataFrame, text_column: str):
80
  alt.Chart(source)
81
  .mark_bar()
82
  .encode(
83
- alt.X(f"{text_column}:Q", bin=True, axis=alt.Axis(title="# chars per text")),
 
 
84
  alt.Y("count()", axis=alt.Axis(title="")),
85
  )
86
  )
@@ -90,7 +99,11 @@ def plot_nchars(data: pd.DataFrame, text_column: str):
90
 
91
  def plot_score(data: pd.DataFrame, label_col: str, label: str):
92
 
93
- source = data.loc[data[label_col] == label].sort_values("score", ascending=False).head(100)
 
 
 
 
94
 
95
  plot = (
96
  alt.Chart(source)
 
55
 
56
  return
57
 
58
+ source = (
59
+ data[label_column]
60
+ .value_counts()
61
+ .reset_index()
62
+ .rename(columns={"index": "Labels", label_column: "Counts"})
63
+ )
64
  source["Props"] = source["Counts"] / source["Counts"].sum()
65
  source["Proportions"] = (source["Props"].round(3) * 100).map("{:,.2f}".format) + "%"
66
 
 
73
  )
74
  )
75
 
76
+ text = bars.mark_text(align="center", baseline="middle", dy=15).encode(
77
+ text="Proportions:O"
78
+ )
79
 
80
  return (bars + text).properties(height=300)
81
 
 
87
  alt.Chart(source)
88
  .mark_bar()
89
  .encode(
90
+ alt.X(
91
+ f"{text_column}:Q", bin=True, axis=alt.Axis(title="# chars per text")
92
+ ),
93
  alt.Y("count()", axis=alt.Axis(title="")),
94
  )
95
  )
 
99
 
100
  def plot_score(data: pd.DataFrame, label_col: str, label: str):
101
 
102
+ source = (
103
+ data.loc[data[label_col] == label]
104
+ .sort_values("score", ascending=False)
105
+ .head(100)
106
+ )
107
 
108
  plot = (
109
  alt.Chart(source)
src/wordifier.py CHANGED
@@ -43,7 +43,9 @@ def wordifier(X, y, X_names: List[str], y_names: List[str], configs=ModelConfigs
43
  # run randomized regression
44
  clf = LogisticRegression(
45
  penalty="l1",
46
- C=configs.PENALTIES.value[np.random.randint(len(configs.PENALTIES.value))],
 
 
47
  solver="liblinear",
48
  multi_class="auto",
49
  max_iter=500,
@@ -51,7 +53,9 @@ def wordifier(X, y, X_names: List[str], y_names: List[str], configs=ModelConfigs
51
  )
52
 
53
  # sample indices to subsample matrix
54
- selection = resample(np.arange(n_instances), replace=True, stratify=y, n_samples=sample_size)
 
 
55
 
56
  # fit
57
  try:
@@ -74,14 +78,28 @@ def wordifier(X, y, X_names: List[str], y_names: List[str], configs=ModelConfigs
74
  neg_scores = neg_scores / configs.NUM_ITERS.value
75
 
76
  # get only active features
77
- pos_positions = np.where(pos_scores >= configs.SELECTION_THRESHOLD.value, pos_scores, 0)
78
- neg_positions = np.where(neg_scores >= configs.SELECTION_THRESHOLD.value, neg_scores, 0)
 
 
 
 
79
 
80
  # prepare DataFrame
81
- pos = [(X_names[i], pos_scores[c, i], y_names[c]) for c, i in zip(*pos_positions.nonzero())]
82
- neg = [(X_names[i], neg_scores[c, i], y_names[c]) for c, i in zip(*neg_positions.nonzero())]
83
-
84
- posdf = pd.DataFrame(pos, columns="word score label".split()).sort_values(["label", "score"], ascending=False)
85
- negdf = pd.DataFrame(neg, columns="word score label".split()).sort_values(["label", "score"], ascending=False)
 
 
 
 
 
 
 
 
 
 
86
 
87
  return posdf, negdf
 
43
  # run randomized regression
44
  clf = LogisticRegression(
45
  penalty="l1",
46
+ C=configs.PENALTIES.value[
47
+ np.random.randint(len(configs.PENALTIES.value))
48
+ ],
49
  solver="liblinear",
50
  multi_class="auto",
51
  max_iter=500,
 
53
  )
54
 
55
  # sample indices to subsample matrix
56
+ selection = resample(
57
+ np.arange(n_instances), replace=True, stratify=y, n_samples=sample_size
58
+ )
59
 
60
  # fit
61
  try:
 
78
  neg_scores = neg_scores / configs.NUM_ITERS.value
79
 
80
  # get only active features
81
+ pos_positions = np.where(
82
+ pos_scores >= configs.SELECTION_THRESHOLD.value, pos_scores, 0
83
+ )
84
+ neg_positions = np.where(
85
+ neg_scores >= configs.SELECTION_THRESHOLD.value, neg_scores, 0
86
+ )
87
 
88
  # prepare DataFrame
89
+ pos = [
90
+ (X_names[i], pos_scores[c, i], y_names[c])
91
+ for c, i in zip(*pos_positions.nonzero())
92
+ ]
93
+ neg = [
94
+ (X_names[i], neg_scores[c, i], y_names[c])
95
+ for c, i in zip(*neg_positions.nonzero())
96
+ ]
97
+
98
+ posdf = pd.DataFrame(pos, columns="word score label".split()).sort_values(
99
+ ["label", "score"], ascending=False
100
+ )
101
+ negdf = pd.DataFrame(neg, columns="word score label".split()).sort_values(
102
+ ["label", "score"], ascending=False
103
+ )
104
 
105
  return posdf, negdf