Anon Anon commited on
Commit
96c49d9
1 Parent(s): 059b5f9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +471 -0
app.py ADDED
@@ -0,0 +1,471 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %%
2
+ import gradio as gr
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ import pandas as pd
6
+ import random
7
+ from matplotlib.ticker import MaxNLocator
8
+ from transformers import pipeline
9
+
10
+ MODEL_NAMES = ["bert-base-uncased", "roberta-base", "bert-large-uncased", "roberta-large"]
11
+ OWN_MODEL_NAME = 'add-a-model'
12
+
13
+ DECIMAL_PLACES = 1
14
+ EPS = 1e-5 # to avoid /0 errors
15
+
16
+ # Example date conts
17
+ DATE_SPLIT_KEY = "DATE"
18
+ START_YEAR = 1801
19
+ STOP_YEAR = 1999
20
+ NUM_PTS = 20
21
+ DATES = np.linspace(START_YEAR, STOP_YEAR, NUM_PTS).astype(int).tolist()
22
+ DATES = [f'{d}' for d in DATES]
23
+
24
+ # Example place conts
25
+ # https://www3.weforum.org/docs/WEF_GGGR_2021.pdf
26
+ # Bottom 10 and top 10 Global Gender Gap ranked countries.
27
+ PLACE_SPLIT_KEY = "PLACE"
28
+ PLACES = [
29
+ "Afghanistan",
30
+ "Yemen",
31
+ "Iraq",
32
+ "Pakistan",
33
+ "Syria",
34
+ "Democratic Republic of Congo",
35
+ "Iran",
36
+ "Mali",
37
+ "Chad",
38
+ "Saudi Arabia",
39
+ "Switzerland",
40
+ "Ireland",
41
+ "Lithuania",
42
+ "Rwanda",
43
+ "Namibia",
44
+ "Sweden",
45
+ "New Zealand",
46
+ "Norway",
47
+ "Finland",
48
+ "Iceland"]
49
+
50
+
51
+ # Example Reddit interest consts
52
+ # in order of increasing self-identified female participation.
53
+ # See http://bburky.com/subredditgenderratios/ , Minimum subreddit size: 400000
54
+ SUBREDDITS = [
55
+ "GlobalOffensive",
56
+ "pcmasterrace",
57
+ "nfl",
58
+ "sports",
59
+ "The_Donald",
60
+ "leagueoflegends",
61
+ "Overwatch",
62
+ "gonewild",
63
+ "Futurology",
64
+ "space",
65
+ "technology",
66
+ "gaming",
67
+ "Jokes",
68
+ "dataisbeautiful",
69
+ "woahdude",
70
+ "askscience",
71
+ "wow",
72
+ "anime",
73
+ "BlackPeopleTwitter",
74
+ "politics",
75
+ "pokemon",
76
+ "worldnews",
77
+ "reddit.com",
78
+ "interestingasfuck",
79
+ "videos",
80
+ "nottheonion",
81
+ "television",
82
+ "science",
83
+ "atheism",
84
+ "movies",
85
+ "gifs",
86
+ "Music",
87
+ "trees",
88
+ "EarthPorn",
89
+ "GetMotivated",
90
+ "pokemongo",
91
+ "news",
92
+ # removing below subreddit as most of the tokens are taken up by it:
93
+ # ['ff', '##ff', '##ff', '##fu', '##u', '##u', '##u', '##u', '##u', '##u', '##u', '##u', '##u', '##u', '##u', ...]
94
+ # "fffffffuuuuuuuuuuuu",
95
+ "Fitness",
96
+ "Showerthoughts",
97
+ "OldSchoolCool",
98
+ "explainlikeimfive",
99
+ "todayilearned",
100
+ "gameofthrones",
101
+ "AdviceAnimals",
102
+ "DIY",
103
+ "WTF",
104
+ "IAmA",
105
+ "cringepics",
106
+ "tifu",
107
+ "mildlyinteresting",
108
+ "funny",
109
+ "pics",
110
+ "LifeProTips",
111
+ "creepy",
112
+ "personalfinance",
113
+ "food",
114
+ "AskReddit",
115
+ "books",
116
+ "aww",
117
+ "sex",
118
+ "relationships",
119
+ ]
120
+
121
+ GENDERED_LIST = [
122
+ ['he', 'she'],
123
+ ['him', 'her'],
124
+ ['his', 'hers'],
125
+ ["himself", "herself"],
126
+ ['male', 'female'],
127
+ ['man', 'woman'],
128
+ ['men', 'women'],
129
+ ["husband", "wife"],
130
+ ['father', 'mother'],
131
+ ['boyfriend', 'girlfriend'],
132
+ ['brother', 'sister'],
133
+ ["actor", "actress"],
134
+ ]
135
+
136
+ # %%
137
+ # Fire up the models
138
+ models = dict()
139
+
140
+ for bert_like in MODEL_NAMES:
141
+ models[bert_like] = pipeline("fill-mask", model=bert_like)
142
+
143
+ # %%
144
+
145
+
146
+ def get_gendered_token_ids():
147
+ male_gendered_tokens = [list[0] for list in GENDERED_LIST]
148
+ female_gendered_tokens = [list[1] for list in GENDERED_LIST]
149
+
150
+ return male_gendered_tokens, female_gendered_tokens
151
+
152
+
153
+ def prepare_text_for_masking(input_text, mask_token, gendered_tokens, split_key):
154
+ text_w_masks_list = [
155
+ mask_token if word.lower() in gendered_tokens else word for word in input_text.split()]
156
+ num_masks = len([m for m in text_w_masks_list if m == mask_token])
157
+
158
+ text_portions = ' '.join(text_w_masks_list).split(split_key)
159
+ return text_portions, num_masks
160
+
161
+
162
+ def get_avg_prob_from_pipeline_outputs(mask_filled_text, gendered_token, num_preds):
163
+ pronoun_preds = [sum([
164
+ pronoun["score"] if pronoun["token_str"].strip().lower() in gendered_token else 0.0
165
+ for pronoun in top_preds])
166
+ for top_preds in mask_filled_text
167
+ ]
168
+ return round(sum(pronoun_preds) / (EPS + num_preds) * 100, DECIMAL_PLACES)
169
+
170
+ # %%
171
+
172
+
173
+ def get_figure(df, gender, n_fit=1):
174
+ df = df.set_index('x-axis')
175
+ cols = df.columns
176
+ xs = list(range(len(df)))
177
+ ys = df[cols[0]]
178
+ fig, ax = plt.subplots()
179
+ # Trying small fig due to rendering issues on HF, not on VS Code
180
+ fig.set_figheight(3)
181
+ fig.set_figwidth(9)
182
+
183
+ # find stackoverflow reference
184
+ p, C_p = np.polyfit(xs, ys, n_fit, cov=1)
185
+ t = np.linspace(min(xs)-1, max(xs)+1, 10*len(xs))
186
+ TT = np.vstack([t**(n_fit-i) for i in range(n_fit+1)]).T
187
+
188
+ # matrix multiplication calculates the polynomial values
189
+ yi = np.dot(TT, p)
190
+ C_yi = np.dot(TT, np.dot(C_p, TT.T)) # C_y = TT*C_z*TT.T
191
+ sig_yi = np.sqrt(np.diag(C_yi)) # Standard deviations are sqrt of diagonal
192
+
193
+ ax.fill_between(t, yi+sig_yi, yi-sig_yi, alpha=.25)
194
+ ax.plot(t, yi, '-')
195
+ ax.plot(df, 'ro')
196
+ ax.legend(list(df.columns))
197
+
198
+ ax.axis('tight')
199
+ ax.set_xlabel("Value injected into input text")
200
+ ax.set_title(
201
+ f"Probability of predicting {gender} pronouns.")
202
+ ax.set_ylabel(f"Softmax prob for pronouns")
203
+ ax.xaxis.set_major_locator(MaxNLocator(6))
204
+ ax.tick_params(axis='x', labelrotation=5)
205
+ return fig
206
+
207
+
208
+ # %%
209
+ def predict_gender_pronouns(
210
+ model_name,
211
+ own_model_name,
212
+ indie_vars,
213
+ split_key,
214
+ normalizing,
215
+ n_fit,
216
+ input_text,
217
+ ):
218
+ """Run inference on input_text for each model type, returning df and plots of percentage
219
+ of gender pronouns predicted as female and male in each target text.
220
+ """
221
+ if model_name not in MODEL_NAMES:
222
+ model = pipeline("fill-mask", model=own_model_name)
223
+ else:
224
+ model = models[model_name]
225
+
226
+ mask_token = model.tokenizer.mask_token
227
+
228
+ indie_vars_list = indie_vars.split(',')
229
+
230
+ male_gendered_tokens, female_gendered_tokens = get_gendered_token_ids()
231
+
232
+ text_segments, num_preds = prepare_text_for_masking(
233
+ input_text, mask_token, male_gendered_tokens + female_gendered_tokens, split_key)
234
+
235
+ male_pronoun_preds = []
236
+ female_pronoun_preds = []
237
+ for indie_var in indie_vars_list:
238
+
239
+ target_text = f"{indie_var}".join(text_segments)
240
+ mask_filled_text = model(target_text)
241
+ # Quick hack as realized return type based on how many MASKs in text.
242
+ if type(mask_filled_text[0]) is not list:
243
+ mask_filled_text = [mask_filled_text]
244
+
245
+ female_pronoun_preds.append(get_avg_prob_from_pipeline_outputs(
246
+ mask_filled_text,
247
+ female_gendered_tokens,
248
+ num_preds
249
+ ))
250
+ male_pronoun_preds.append(get_avg_prob_from_pipeline_outputs(
251
+ mask_filled_text,
252
+ male_gendered_tokens,
253
+ num_preds
254
+ ))
255
+
256
+ if normalizing:
257
+ total_gendered_probs = np.add(
258
+ female_pronoun_preds, male_pronoun_preds)
259
+ female_pronoun_preds = np.around(
260
+ np.divide(female_pronoun_preds, total_gendered_probs+EPS)*100,
261
+ decimals=DECIMAL_PLACES
262
+ )
263
+ male_pronoun_preds = np.around(
264
+ np.divide(male_pronoun_preds, total_gendered_probs+EPS)*100,
265
+ decimals=DECIMAL_PLACES
266
+ )
267
+
268
+ results_df = pd.DataFrame({'x-axis': indie_vars_list})
269
+ results_df['female_pronouns'] = female_pronoun_preds
270
+ results_df['male_pronouns'] = male_pronoun_preds
271
+ female_fig = get_figure(results_df.drop(
272
+ 'male_pronouns', axis=1), 'female', n_fit,)
273
+ male_fig = get_figure(results_df.drop(
274
+ 'female_pronouns', axis=1), 'male', n_fit,)
275
+ display_text = f"{random.choice(indie_vars_list)}".join(text_segments)
276
+
277
+ return (
278
+ display_text,
279
+ female_fig,
280
+ male_fig,
281
+ results_df,
282
+ )
283
+
284
+
285
+ # %%
286
+ title = "Causing Gender Pronouns"
287
+ description = """
288
+ ## Intro
289
+ """
290
+
291
+
292
+ date_example = [
293
+ MODEL_NAMES[1],
294
+ '',
295
+ ', '.join(DATES),
296
+ 'DATE',
297
+ "False",
298
+ 1,
299
+ 'She was a teenager in DATE.'
300
+ ]
301
+
302
+
303
+ place_example = [
304
+ MODEL_NAMES[0],
305
+ '',
306
+ ', '.join(PLACES),
307
+ 'PLACE',
308
+ "False",
309
+ 1,
310
+ 'She became an adult in PLACE.'
311
+ ]
312
+
313
+
314
+ subreddit_example = [
315
+ MODEL_NAMES[3],
316
+ '',
317
+ ', '.join(SUBREDDITS),
318
+ 'SUBREDDIT',
319
+ "False",
320
+ 1,
321
+ 'She was a kid. SUBREDDIT.'
322
+ ]
323
+
324
+ own_model_example = [
325
+ OWN_MODEL_NAME,
326
+ 'emilyalsentzer/Bio_ClinicalBERT',
327
+ ', '.join(DATES),
328
+ 'DATE',
329
+ "False",
330
+ 1,
331
+ 'She was exposed to the virus in DATE.'
332
+ ]
333
+
334
+
335
+ def date_fn():
336
+ return date_example
337
+
338
+
339
+ def place_fn():
340
+ return place_example
341
+
342
+
343
+ def reddit_fn():
344
+ return subreddit_example
345
+
346
+
347
+ def your_fn():
348
+ return own_model_example
349
+
350
+
351
+ # %%
352
+ demo = gr.Blocks()
353
+ with demo:
354
+ gr.Markdown("# Spurious Correlation Evaluation for Pre-trained LLMs")
355
+ gr.Markdown("Find spurious correlations between seemingly independent variables (for example between `gender` and `time`) in almost any BERT-like LLM on Hugging Face, below.")
356
+
357
+ gr.Markdown("See why this happens how in [our ICLR paper under review](https://openreview.net/pdf?id=25VgHaPz0l4)".)
358
+
359
+ gr.Markdown("## Instructions for this Demo")
360
+ gr.Markdown("1) Click on one of the examples below (where we sweep through a spectrum of `places`, `dates` and `subreddits`) to pre-populate the input fields.")
361
+ gr.Markdown("2) Check out the pre-populated fields as you scroll down to the ['Hit Submit...'] button!")
362
+ gr.Markdown("3) Repeat steps (1) and (2) with more pre-populated inputs or with your own values in the input fields!")
363
+
364
+ gr.Markdown("## Example inputs")
365
+ gr.Markdown("Click a button below to pre-populate input fields with example values. Then scroll down to Hit Submit to generate predictions.")
366
+ with gr.Row():
367
+ date_gen = gr.Button('Click for date example inputs')
368
+ gr.Markdown("<-- x-axis sorted by older to more recent dates:")
369
+
370
+ place_gen = gr.Button('Click for country example inputs')
371
+ gr.Markdown(
372
+ "<-- x-axis sorted by bottom 10 and top 10 [Global Gender Gap](https://www3.weforum.org/docs/WEF_GGGR_2021.pdf) ranked countries:")
373
+
374
+ subreddit_gen = gr.Button('Click for Subreddit example inputs')
375
+ gr.Markdown(
376
+ "<-- x-axis sorted in order of increasing self-identified female participation (see [bburky](http://bburky.com/subredditgenderratios/)): ")
377
+
378
+ your_gen = gr.Button('Add-a-model example inputs')
379
+ gr.Markdown("<-- x-axis dates, with your own model loaded! (If first time, try another example, it can take a while to load new model.)")
380
+
381
+ gr.Markdown("## Input fields")
382
+ gr.Markdown(
383
+ f"A) Pick a spectrum of comma separated values for text injection and x-axis.")
384
+
385
+ with gr.Row():
386
+ x_axis = gr.Textbox(
387
+ lines=3,
388
+ label="A) Comma separated values for text injection and x-axis",
389
+ )
390
+
391
+
392
+ gr.Markdown("B) Pick a pre-loaded BERT-family model of interest on the right.")
393
+ gr.Markdown(f"Or C) select `{OWN_MODEL_NAME}`, then add the mame of any other Hugging Face model that supports the [fill-mask](https://huggingface.co/models?pipeline_tag=fill-mask) task on the right (note: this may take some time to load).")
394
+
395
+ with gr.Row():
396
+ model_name = gr.Radio(
397
+ MODEL_NAMES + [OWN_MODEL_NAME],
398
+ type="value",
399
+ label="B) BERT-like model.",
400
+ )
401
+ own_model_name = gr.Textbox(
402
+ label="C) If you selected an 'add-a-model' model, put any Hugging Face pipeline model name (that supports the fill-mask task) here.",
403
+ )
404
+
405
+ gr.Markdown("D) Pick if you want to the predictions normalied to these gendered terms only.")
406
+ gr.Markdown("E) Also tell the demo what special token you will use in your input text, that you would like replaced with the spectrum of values you listed above.")
407
+ gr.Markdown("And F) the degree of polynomial fit used for high-lighting potential spurious association.")
408
+
409
+
410
+ with gr.Row():
411
+ to_normalize = gr.Dropdown(
412
+ ["False", "True"],
413
+ label="D) Normalize model's predictions to only the gendered ones?",
414
+ type="index",
415
+ )
416
+ place_holder = gr.Textbox(
417
+ label="E) Special token place-holder",
418
+ )
419
+ n_fit = gr.Dropdown(
420
+ list(range(1, 5)),
421
+ label="F) Degree of polynomial fit",
422
+ type="value",
423
+ )
424
+
425
+ gr.Markdown(
426
+ "G) Finally, add input text that includes at least one gendered pronouns and one place-holder token specified above.")
427
+
428
+ with gr.Row():
429
+ input_text = gr.Textbox(
430
+ lines=2,
431
+ label="G) Input text with pronouns and place-holder token",
432
+ )
433
+
434
+ gr.Markdown("## Outputs!")
435
+ #gr.Markdown("Scroll down and 'Hit Submit'!")
436
+ with gr.Row():
437
+ btn = gr.Button("Hit submit to generate predictions!")
438
+
439
+ with gr.Row():
440
+ sample_text = gr.Textbox(
441
+ type="auto", label="Output text: Sample of text fed to model")
442
+ with gr.Row():
443
+ female_fig = gr.Plot(type="auto")
444
+ male_fig = gr.Plot(type="auto")
445
+ with gr.Row():
446
+ df = gr.Dataframe(
447
+ show_label=True,
448
+ overflow_row_behaviour="show_ends",
449
+ label="Table of softmax probability for pronouns predictions",
450
+ )
451
+
452
+ with gr.Row():
453
+
454
+ date_gen.click(date_fn, inputs=[], outputs=[model_name, own_model_name,
455
+ x_axis, place_holder, to_normalize, n_fit, input_text])
456
+ place_gen.click(place_fn, inputs=[], outputs=[
457
+ model_name, own_model_name, x_axis, place_holder, to_normalize, n_fit, input_text])
458
+ subreddit_gen.click(reddit_fn, inputs=[], outputs=[
459
+ model_name, own_model_name, x_axis, place_holder, to_normalize, n_fit, input_text])
460
+ your_gen.click(your_fn, inputs=[], outputs=[
461
+ model_name, own_model_name, x_axis, place_holder, to_normalize, n_fit, input_text])
462
+
463
+ btn.click(
464
+ predict_gender_pronouns,
465
+ inputs=[model_name, own_model_name, x_axis, place_holder,
466
+ to_normalize, n_fit, input_text],
467
+ outputs=[sample_text, female_fig, male_fig, df])
468
+
469
+
470
+ demo.launch(debug=True)
471
+