Spaces:
Runtime error
Runtime error
from datasets import load_dataset | |
from disaggregators import Disaggregator, DisaggregationModuleLabels, CustomDisaggregator | |
from disaggregators.disaggregation_modules.age import Age, AgeLabels, AgeConfig | |
import matplotlib | |
# Use the non-interactive Agg backend. 'TkAgg' needs tkinter and a display;
# on a headless host matplotlib.use() raises ImportError for it, which kills
# the script at import time. This script never draws a figure itself, so a
# non-GUI backend loses nothing.
matplotlib.use('Agg')
import joblib | |
import os | |
# Joblib file holding one summary dict per dataset name; it is rewritten with
# joblib.dump at the end of this script.
cache_file = "cached_data.pkl"
cache_dict = {}
if os.path.exists(cache_file):
    # Start from the existing cache so entries for dataset names this run does
    # not regenerate survive when the cache is written back.
    # NOTE(review): joblib.load unpickles the file and can execute arbitrary
    # code; only ever load a cache produced by this project.
    cache_dict = joblib.load(cache_file)
class MeSHAgeLabels(AgeLabels):
    """Custom age-group labels used by the ``age`` module for MedMCQA.

    The class name suggests these follow the MeSH age-group descriptors
    (NOTE(review): confirm). The ``age`` instance passes ``list(MeSHAgeLabels)``
    together with eight breakpoints, one per member, so the member order
    presumably has to stay aligned with those breakpoints. Do not reorder
    members without updating the breakpoints too.
    """
    INFANT = "infant"
    CHILD_PRESCHOOL = "child_preschool"
    CHILD = "child"
    ADOLESCENT = "adolescent"
    ADULT = "adult"
    MIDDLE_AGED = "middle_aged"
    AGED = "aged"
    AGED_80_OVER = "aged_80_over"
# Age disaggregation module built on MeSHAgeLabels; it is listed in the
# "medmcqa" entry of ``configs``, whose text column is also "question".
age = Age(
    config=AgeConfig(
        labels=MeSHAgeLabels,
        # NOTE(review): this wraps the full label list in an extra one-element
        # list, while ``breakpoints`` has one value per label. Confirm this
        # nesting is what AgeConfig expects.
        ages=[list(MeSHAgeLabels)],
        # Eight values, one per MeSHAgeLabels member; presumably ages in years
        # marking group boundaries (TODO confirm units and whether each value
        # starts or ends its group).
        breakpoints=[0, 2, 5, 12, 18, 44, 64, 79]
    ),
    column="question"
)
class TabsSpacesLabels(DisaggregationModuleLabels):
    """Label set for ``TabsSpaces``; every row gets exactly one of these."""
    TABS = "tabs"
    SPACES = "spaces"


class TabsSpaces(CustomDisaggregator):
    """Custom module that tags a row's text as "tabs" or "spaces".

    The check is coarse: a single tab character anywhere in the text
    (indentation, string literals, comments) makes the row "tabs"; every
    other row is "spaces". Exactly one of the two labels is True.
    """

    module_id = "tabs_spaces"
    labels = TabsSpacesLabels

    def __call__(self, row, *args, **kwargs):
        # ``self.column`` is presumably set by the base class from
        # Disaggregator(column=...); verify against the library.
        text = row[self.column]
        has_tab = "\t" in text
        return {self.labels.TABS: has_tab, self.labels.SPACES: not has_tab}
class ReactComponentLabels(DisaggregationModuleLabels):
    """Label set for ``ReactComponent``: class-style vs. function-style."""
    CLASS = "class"
    FUNCTION = "function"


class ReactComponent(CustomDisaggregator):
    """Custom module that tags a row's source as class- or function-style React.

    A row is "class" when its text contains ``extends React.Component`` or
    ``extends Component``. Everything else falls into "function", including
    text with no React code at all and ``extends PureComponent`` /
    ``extends React.PureComponent`` subclasses (neither marker matches
    those). Exactly one of the two labels is True.
    """

    module_id = "react_component"
    labels = ReactComponentLabels

    def __call__(self, row, *args, **kwargs):
        source = row[self.column]
        is_class = any(
            marker in source
            for marker in ("extends React.Component", "extends Component")
        )
        return {self.labels.CLASS: is_class, self.labels.FUNCTION: not is_class}
# Per-dataset settings. Each value is unpacked as keyword arguments into
# generate_cached_data() further down. ``feature_names`` maps label-column
# keys (which look like "<module id>.<label>") and bare module ids to display
# strings. The "Both" entry, where present, presumably names the combined
# two-module view (confirm against the consumer of the cache).
configs = {
    # LAION (laion2B-en) subset; one module, referenced by name.
    "laion": {
        "disaggregation_modules": ["continent"],
        "dataset_name": "society-ethics/laion2B-en_continents",
        "column": "TEXT",
        "feature_names": {
            "continent.africa": "Africa",
            "continent.americas": "Americas",
            "continent.asia": "Asia",
            "continent.europe": "Europe",
            "continent.oceania": "Oceania",
            # Parent level
            "continent": "Continent",
        },
    },
    # MedMCQA questions; the custom ``age`` instance plus "gender" by name.
    "medmcqa": {
        "disaggregation_modules": [age, "gender"],
        "dataset_name": "society-ethics/medmcqa_age_gender_custom",
        "column": "question",
        "feature_names": {
            "age.infant": "Infant",
            "age.child_preschool": "Preschool",
            "age.child": "Child",
            "age.adolescent": "Adolescent",
            "age.adult": "Adult",
            "age.middle_aged": "Middle Aged",
            "age.aged": "Aged",
            "age.aged_80_over": "Aged 80+",
            "gender.male": "Male",
            "gender.female": "Female",
            # Parent level
            "gender": "Gender",
            "age": "Age",
            "Both": "Age + Gender",
        },
    },
    # The Stack source files; the two custom module classes, passed as classes.
    "stack": {
        "disaggregation_modules": [TabsSpaces, ReactComponent],
        "dataset_name": "society-ethics/the-stack-tabs_spaces",
        "column": "content",
        "feature_names": {
            "react_component.class": "Class",
            "react_component.function": "Function",
            "tabs_spaces.tabs": "Tabs",
            "tabs_spaces.spaces": "Spaces",
            # Parent level
            "react_component": "React Component Syntax",
            "tabs_spaces": "Tabs vs. Spaces",
            "Both": "React Component Syntax + Tabs vs. Spaces",
        },
    },
}
def generate_cached_data(disaggregation_modules, dataset_name, column, feature_names):
    """Load a pre-labelled dataset and count rows per label combination.

    The disaggregator is only used for its field and module names; it is never
    run over the rows here, so the dataset must already contain one column per
    disaggregator field.

    Args:
        disaggregation_modules: Modules handed straight to ``Disaggregator``
            (names, instances or custom module classes, as in ``configs``).
        dataset_name: Hub dataset id, loaded with ``split="train"``.
        column: Text column name given to ``Disaggregator``; also stored in
            the result unchanged.
        feature_names: Display-name mapping, stored in the result unchanged.

    Returns:
        dict with keys:
            "fields": the disaggregator's fields plus the extra entry "None"
                (presumably a "no selection" option for the UI; confirm).
            "data_fields": the disaggregator's fields as it reports them.
            "distributions": ``DataFrame.value_counts()`` over the field
                columns in sorted-name order, i.e. the row count for each
                distinct combination of label values.
            "disaggregators": the ``name`` of each module in the disaggregator.
            "column", "feature_names": the corresponding arguments.

    Raises:
        KeyError: if the dataset lacks one of the disaggregator's fields.
    """
    disaggregator = Disaggregator(disaggregation_modules, column=column)
    fields = disaggregator.fields
    # Sorted so the column order, and hence the index-level order of the
    # counts, does not depend on the iteration order of ``fields``.
    label_columns = sorted(fields)

    ds = load_dataset(dataset_name, split="train")
    # Only the label columns are counted. Dropping everything else (e.g. the
    # raw text column) before conversion avoids copying it into pandas.
    keep = set(label_columns)
    extra_columns = [name for name in ds.column_names if name not in keep]
    if extra_columns:
        ds = ds.remove_columns(extra_columns)
    distributions = ds.to_pandas()[label_columns].value_counts()

    return {
        "fields": {*fields, "None"},
        "data_fields": fields,
        "distributions": distributions,
        "disaggregators": [module.name for module in disaggregator.modules],
        "column": column,
        "feature_names": feature_names,
    }
# Rebuild the summary for each configured dataset, in the order laion,
# medmcqa, stack, then merge them into the cache. All three are computed
# before the cache dict is touched; entries stored under other names are
# kept. Finally persist the merged cache.
cache_dict.update(
    {
        name: generate_cached_data(**configs[name])
        for name in ("laion", "medmcqa", "stack")
    }
)
joblib.dump(cache_dict, cache_file)