# Source-view header captured along with this file (not code):
# Trent — "Add gender evaluation demo" (commit 75efc41), 1.71 kB
import gzip
import json
import numpy as np
import pandas as pd
import streamlit as st
import torch
import tqdm
from sentence_transformers import SentenceTransformer
@st.cache(allow_output_mutation=True)
def load_model(model_name, model_dict):
    """Instantiate the SentenceTransformer model(s) registered under a name.

    Models are constructed lazily, only when this function is first called
    for a given name; ``st.cache`` then reuses the result for later calls
    with the same arguments.

    NOTE(review): ``st.cache`` is deprecated in newer Streamlit releases in
    favour of ``st.cache_resource``; kept for the Streamlit version this
    file targets.

    Args:
        model_name: Key to look up in ``model_dict``.
        model_dict: Mapping from a name to either a single model identifier
            (``str``) or an iterable of model identifiers.

    Returns:
        A single ``SentenceTransformer`` when the entry is a string,
        otherwise a list of ``SentenceTransformer`` objects in the entry's
        iteration order.

    Raises:
        KeyError: ``model_name`` is not a key of ``model_dict``.
        TypeError: the entry is neither a string nor an iterable.
    """
    # Explicit check rather than ``assert``, which ``python -O`` strips.
    if model_name not in model_dict:
        raise KeyError(f"Unknown model name: {model_name!r}")

    model_ids = model_dict[model_name]
    # Test for str first: strings are iterable too, and must not be split
    # into one model per character.
    if isinstance(model_ids, str):
        return SentenceTransformer(model_ids)
    if hasattr(model_ids, '__iter__'):
        return [SentenceTransformer(name) for name in model_ids]
    # Reject anything else explicitly instead of failing on an unbound local.
    raise TypeError(
        f"Model entry for {model_name!r} must be a string or an iterable of "
        f"strings, got {type(model_ids).__name__}"
    )
@st.cache(allow_output_mutation=True)
def load_embeddings(path='./data/stackoverflow-titles-distilbert-emb.csv',
                    max_rows=10000):
    """Load pre-generated embeddings from a text file as a float32 tensor.

    Args:
        path: Text file readable by ``np.loadtxt``. No delimiter is passed,
            so values are split on whitespace despite the ``.csv`` name.
            Presumably one embedding per row — confirm against the
            generation script.
        max_rows: Maximum number of rows to read from ``path``.

    Returns:
        A ``torch.float32`` tensor built from the rows read.
    """
    # The embeddings were computed offline; they are only read back here.
    # NOTE(review): row i presumably corresponds to title i of the
    # Stack Overflow titles dataset — verify before relying on alignment.
    corpus_emb = torch.from_numpy(np.loadtxt(path, max_rows=max_rows))
    # np.loadtxt produces float64 by default; downcast to float32.
    return corpus_emb.float()
@st.cache(allow_output_mutation=True)
def filter_questions(tag, max_questions=10000, max_posts=6_000_000,
                     path="./data/stackoverflow-titles.jsonl.gz"):
    """Return up to ``max_questions`` posts whose ``"tags"`` contain ``tag``.

    Posts are read in file order from a gzip-compressed file holding one
    JSON object per line. At most ``max_posts`` lines are scanned.

    Args:
        tag: Value tested with ``in`` against each post's ``"tags"`` field.
        max_questions: Maximum number of matching posts to return; values
            <= 0 return an empty list.
        max_posts: Maximum number of lines to scan from ``path``.
        path: Location of the gzipped JSON-lines file.

    Returns:
        A list of the matching post dicts, in file order.
    """
    # ``import tqdm`` alone is not guaranteed to load the ``tqdm.auto``
    # submodule, so import it explicitly here.
    from tqdm.auto import tqdm as progress_bar

    filtered_posts = []
    # Check before reading: an append-then-check loop would still return one
    # post when max_questions <= 0.
    if max_questions <= 0:
        return filtered_posts

    # Using the progress bar as a context manager closes it on early exit.
    with gzip.open(path, "rt") as f_in, \
            progress_bar(f_in, total=max_posts, desc="Load data") as lines:
        # Filter while streaming. Only matching posts are kept in memory, and
        # reading stops as soon as enough matches are found, instead of first
        # materialising up to ``max_posts`` parsed posts.
        for n_read, line in enumerate(lines, start=1):
            post = json.loads(line)
            if tag in post["tags"]:
                filtered_posts.append(post)
                if len(filtered_posts) >= max_questions:
                    break
            # Never scan more than ``max_posts`` lines.
            if n_read >= max_posts:
                break
    return filtered_posts
def load_gender_data():
    """Pick one random row of the gendered bias-evaluation dataset.

    Returns:
        A ``(base_sentence, male_sentence, female_sentence)`` tuple taken
        from a single randomly sampled row.
    """
    row = load_gendered_dataset().sample(n=1).iloc[0]
    base = row["base_sentence"]
    male = row["male_sentence"]
    female = row["female_sentence"]
    return base, male, female
@st.cache(allow_output_mutation=True)
def load_gendered_dataset():
    """Read the bias-evaluation CSV into a DataFrame.

    Cached by Streamlit. With ``allow_output_mutation=True``, the cached
    object is returned as-is on later calls.
    """
    return pd.read_csv('./data/bias_evaluation.csv')