import gzip
import json

import numpy as np
import pandas as pd
import streamlit as st
import torch
from sentence_transformers import SentenceTransformer
from tqdm.auto import tqdm

@st.cache(allow_output_mutation=True)
def load_model(model_name, model_dict):
    assert model_name in model_dict.keys()
    # Lazy downloading: models are only fetched the first time they are requested
    model_ids = model_dict[model_name]
    if isinstance(model_ids, str):
        # A single model id maps to a single SentenceTransformer
        output = SentenceTransformer(model_ids)
    elif hasattr(model_ids, '__iter__'):
        # An iterable of ids maps to a list of models
        output = [SentenceTransformer(name) for name in model_ids]
    else:
        raise ValueError(f"Unsupported model id(s) for '{model_name}': {model_ids!r}")
    return output

@st.cache(allow_output_mutation=True)
def load_embeddings():
    # Corpus embeddings are pre-generated; only the first 10,000 rows are loaded
    corpus_emb = torch.from_numpy(
        np.loadtxt('./data/stackoverflow-titles-distilbert-emb.csv', max_rows=10000)
    )
    return corpus_emb.float()

@st.cache(allow_output_mutation=True)
def filter_questions(tag, max_questions=10000):
    posts = []
    max_posts = int(6e6)
    with gzip.open("./data/stackoverflow-titles.jsonl.gz", "rt") as fIn:
        for line in tqdm(fIn, total=max_posts, desc="Load data"):
            posts.append(json.loads(line))
            if len(posts) >= max_posts:
                break

    # Keep only posts carrying the requested tag, up to max_questions
    filtered_posts = []
    for post in posts:
        if tag in post["tags"]:
            filtered_posts.append(post)
            if len(filtered_posts) >= max_questions:
                break
    return filtered_posts

def load_gender_data():
    # Draw one random row from the gendered evaluation dataset
    df = load_gendered_dataset()
    sampled_row = df.sample().iloc[0]
    return sampled_row.base_sentence, sampled_row.male_sentence, sampled_row.female_sentence

@st.cache(allow_output_mutation=True)
def load_gendered_dataset():
    return pd.read_csv('./data/bias_evaluation.csv')
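
# --------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original file): a minimal way the
# helpers above could drive a Streamlit page. The registry key and model id
# below are assumptions for illustration; the real Space may register other
# models.
if __name__ == '__main__':
    from sentence_transformers import util

    # Assumed mapping of display name -> sentence-transformers model id
    model = load_model('distilbert', {'distilbert': 'distilbert-base-nli-stsb-mean-tokens'})

    st.header('Gender bias probe')
    base, male, female = load_gender_data()
    emb = model.encode([base, male, female], convert_to_tensor=True)

    # Compare the base sentence with its male / female variants
    st.write('Base sentence:', base)
    st.write('Similarity to male variant:', float(util.cos_sim(emb[0], emb[1])))
    st.write('Similarity to female variant:', float(util.cos_sim(emb[0], emb[2])))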