File size: 1,710 Bytes
75c3a89
 
 
75efc41
75c3a89
a41bdbc
75c3a89
 
a41bdbc
 
 
 
5cd1ac6
 
a41bdbc
5cd1ac6
31f3439
 
 
 
6e03e5d
 
75c3a89
 
 
 
73ee9f2
75c3a89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75efc41
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import gzip
import json
import numpy as np
import pandas as pd

import streamlit as st
import torch
import tqdm
from sentence_transformers import SentenceTransformer


@st.cache(allow_output_mutation=True)
def load_model(model_name, model_dict):
    """Instantiate the SentenceTransformer model(s) registered under *model_name*.

    Args:
        model_name: Key to look up in *model_dict*.
        model_dict: Mapping from a model name to either a single model id
            string or an iterable of model id strings.

    Returns:
        A ``SentenceTransformer`` when the entry is a string, otherwise a
        list of ``SentenceTransformer`` objects in the entry's iteration order.

    Raises:
        KeyError: If *model_name* is not a key of *model_dict*.
        TypeError: If the entry is neither a string nor an iterable.
    """
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if model_name not in model_dict:
        raise KeyError(f"Unknown model name: {model_name!r}")

    # Lazy loading: models are only constructed here, when first requested
    # (the result is then memoized by ``st.cache``).
    model_ids = model_dict[model_name]

    # Check ``str`` first: strings are iterable too and must not be split
    # into characters.
    if isinstance(model_ids, str):
        return SentenceTransformer(model_ids)
    if hasattr(model_ids, '__iter__'):
        return [SentenceTransformer(name) for name in model_ids]

    # Previously this fell through and failed with UnboundLocalError.
    raise TypeError(
        f"Model entry for {model_name!r} must be a string or an iterable of "
        f"strings, got {type(model_ids).__name__}"
    )

@st.cache(allow_output_mutation=True)
def load_embeddings(path='./data/stackoverflow-titles-distilbert-emb.csv',
                    max_rows=10000):
    """Load pre-generated corpus embeddings as a float32 torch tensor.

    Args:
        path: Text file of numbers parsed with ``np.loadtxt`` (default
            delimiter, i.e. whitespace — presumably how the file was
            written despite the ``.csv`` suffix; confirm).
        max_rows: Maximum number of rows to read; ``None`` reads every row.
            Defaults to the previously hard-coded 10000.

    Returns:
        ``torch.Tensor`` converted to float32 from the float64 array that
        ``np.loadtxt`` produces; presumably shaped (rows, embedding_dim).
    """
    embeddings = np.loadtxt(path, max_rows=max_rows)
    return torch.from_numpy(embeddings).float()

@st.cache(allow_output_mutation=True)
def filter_questions(tag, max_questions=10000, max_posts=6000000,
                     path="./data/stackoverflow-titles.jsonl.gz"):
    """Return the first posts carrying *tag* from a gzipped JSON-lines dump.

    Reads at most *max_posts* lines of *path* in file order and keeps the
    posts whose ``"tags"`` field contains *tag*, stopping as soon as
    *max_questions* matches have been collected.

    Args:
        tag: Value tested with ``in`` against each post's ``"tags"`` field.
        max_questions: Maximum number of matching posts to return; a
            non-positive value returns an empty list.
        max_posts: Maximum number of lines (posts) to read from *path*.
        path: Gzip-compressed UTF-8 file with one JSON object per line.

    Returns:
        List of decoded post dicts, in file order.
    """
    from itertools import islice

    # Imported explicitly: a bare ``import tqdm`` does not load the
    # ``tqdm.auto`` submodule, so ``tqdm.auto.tqdm`` only worked if some
    # other import happened to load it first.
    from tqdm.auto import tqdm as progress_bar

    filtered_posts = []
    if max_questions <= 0:
        return filtered_posts

    # Filter while streaming instead of first materializing up to
    # ``max_posts`` decoded posts; the selected posts are the same, but
    # memory stays proportional to ``max_questions``. ``islice`` caps the
    # number of lines read at ``max_posts``.
    with gzip.open(path, "rt", encoding="utf-8") as fin:
        lines = progress_bar(islice(fin, max_posts), total=max_posts,
                             desc="Load data")
        for line in lines:
            post = json.loads(line)
            if tag in post["tags"]:
                filtered_posts.append(post)
                if len(filtered_posts) >= max_questions:
                    break
    return filtered_posts

def load_gender_data():
    """Draw one random example from the gendered bias-evaluation dataset.

    Returns:
        Tuple ``(base_sentence, male_sentence, female_sentence)`` read from a
        single row chosen by ``DataFrame.sample`` (one row by default).
    """
    fields = ("base_sentence", "male_sentence", "female_sentence")
    row = load_gendered_dataset().sample().iloc[0]
    return tuple(row[field] for field in fields)

@st.cache(allow_output_mutation=True)
def load_gendered_dataset():
    """Read the bias-evaluation table from ``./data/bias_evaluation.csv``.

    The result is memoized by ``st.cache``; ``allow_output_mutation=True``
    means the same DataFrame object is handed back on later calls.
    """
    csv_path = './data/bias_evaluation.csv'
    return pd.read_csv(csv_path)