"""Precompute disaggregation statistics for several Hugging Face datasets.

For each entry in ``configs`` this script loads the dataset's ``train`` split,
counts how often each combination of disaggregation labels occurs, and stores
the results (together with display names) in the joblib cache file
``cached_data.pkl``.
"""
import os

from datasets import load_dataset
from disaggregators import (
    Disaggregator,
    DisaggregationModuleLabels,
    CustomDisaggregator,
)
from disaggregators.disaggregation_modules.age import Age, AgeLabels, AgeConfig
import matplotlib

# NOTE(review): matplotlib is not otherwise used in this script; presumably the
# backend choice matters for a dependency that plots -- confirm or remove.
matplotlib.use('TKAgg')
import joblib

# Path of the joblib cache written at the end of this script.
cache_file = "cached_data.pkl"

# Start from any previously cached results so that entries not regenerated
# below are preserved. joblib.load unpickles the file, so only ever load a
# cache this script produced -- never an untrusted file.
cache_dict = {}
if os.path.exists(cache_file):
    cache_dict = joblib.load(cache_file)


class MeSHAgeLabels(AgeLabels):
    """Age-group labels (named after MeSH age descriptors) for the medmcqa config."""

    INFANT = "infant"
    CHILD_PRESCHOOL = "child_preschool"
    CHILD = "child"
    ADOLESCENT = "adolescent"
    ADULT = "adult"
    MIDDLE_AGED = "middle_aged"
    AGED = "aged"
    AGED_80_OVER = "aged_80_over"


# Age disaggregator over the "question" column. One breakpoint per label;
# presumably each breakpoint is the lower age bound (years) of the label at the
# same position -- confirm against AgeConfig.
age = Age(
    config=AgeConfig(
        labels=MeSHAgeLabels,
        ages=[list(MeSHAgeLabels)],
        breakpoints=[0, 2, 5, 12, 18, 44, 64, 79],
    ),
    column="question",
)


class TabsSpacesLabels(DisaggregationModuleLabels):
    """Labels produced by ``TabsSpaces``."""

    TABS = "tabs"
    SPACES = "spaces"


class TabsSpaces(CustomDisaggregator):
    """Classify a row's text as tab-indented or space-indented.

    A row is ``tabs`` if its text contains at least one tab character anywhere,
    otherwise ``spaces``; exactly one of the two labels is True.
    """

    module_id = "tabs_spaces"
    labels = TabsSpacesLabels

    def __call__(self, row, *args, **kwargs):
        # self.column is presumably set by the framework from ``column=``.
        has_tab = "\t" in row[self.column]
        return {self.labels.TABS: has_tab, self.labels.SPACES: not has_tab}


class ReactComponentLabels(DisaggregationModuleLabels):
    """Labels produced by ``ReactComponent``."""

    CLASS = "class"
    FUNCTION = "function"


class ReactComponent(CustomDisaggregator):
    """Classify React code as class-based or function-based components.

    A row is ``class`` if its text contains ``extends React.Component`` or
    ``extends Component``; every other row is counted as ``function``.
    """

    module_id = "react_component"
    labels = ReactComponentLabels

    def __call__(self, row, *args, **kwargs):
        text = row[self.column]
        # Both spellings are needed: neither string contains the other.
        is_class = "extends React.Component" in text or "extends Component" in text
        return {self.labels.CLASS: is_class, self.labels.FUNCTION: not is_class}


# Per-dataset keyword arguments for generate_cached_data. ``feature_names``
# maps "<module_id>.<label>" field names, and parent module ids, to display
# names; "Both" presumably labels the combination of the two modules.
configs = {
    "laion": {
        "disaggregation_modules": ["continent"],
        "dataset_name": "society-ethics/laion2B-en_continents",
        "column": "TEXT",
        "feature_names": {
            "continent.africa": "Africa",
            "continent.americas": "Americas",
            "continent.asia": "Asia",
            "continent.europe": "Europe",
            "continent.oceania": "Oceania",
            # Parent level
            "continent": "Continent",
        },
    },
    "medmcqa": {
        "disaggregation_modules": [age, "gender"],
        "dataset_name": "society-ethics/medmcqa_age_gender_custom",
        "column": "question",
        "feature_names": {
            "age.infant": "Infant",
            "age.child_preschool": "Preschool",
            "age.child": "Child",
            "age.adolescent": "Adolescent",
            "age.adult": "Adult",
            "age.middle_aged": "Middle Aged",
            "age.aged": "Aged",
            "age.aged_80_over": "Aged 80+",
            "gender.male": "Male",
            "gender.female": "Female",
            # Parent level
            "gender": "Gender",
            "age": "Age",
            "Both": "Age + Gender",
        },
    },
    "stack": {
        "disaggregation_modules": [TabsSpaces, ReactComponent],
        "dataset_name": "society-ethics/the-stack-tabs_spaces",
        "column": "content",
        "feature_names": {
            "react_component.class": "Class",
            "react_component.function": "Function",
            "tabs_spaces.tabs": "Tabs",
            "tabs_spaces.spaces": "Spaces",
            # Parent level
            "react_component": "React Component Syntax",
            "tabs_spaces": "Tabs vs. Spaces",
            "Both": "React Component Syntax + Tabs vs. Spaces",
        },
    },
}


def generate_cached_data(disaggregation_modules, dataset_name, column, feature_names):
    """Summarize how often each combination of disaggregation labels occurs.

    Args:
        disaggregation_modules: Module ids (str) and/or module instances or
            classes, passed straight to ``Disaggregator``.
        dataset_name: Hugging Face Hub dataset id; its ``train`` split is loaded.
        column: Name of the text column the disaggregators read.
        feature_names: Mapping of field / parent names to display labels;
            stored unchanged in the result.

    Returns:
        A dict with:
          ``fields``: the disaggregator's fields plus the extra entry "None"
              (presumably a "no selection" option for consumers -- confirm);
          ``data_fields``: ``disaggregator.fields`` as-is;
          ``distributions``: pandas Series counting each unique combination of
              the field columns (``DataFrame.value_counts``);
          ``disaggregators``: the ``name`` of each module in the disaggregator;
          ``column`` and ``feature_names``: echoed from the arguments.

    Raises:
        KeyError: if the dataset lacks a column for any disaggregator field.
    """
    disaggregator = Disaggregator(disaggregation_modules, column=column)
    df = load_dataset(dataset_name, split="train").to_pandas()

    all_fields = {*disaggregator.fields, "None"}
    # NOTE: the dataset is never run through the disaggregator here; it must
    # already contain one column per field (the Hub datasets appear to be
    # pre-disaggregated). Sorting gives a stable column order for the index.
    distributions = df[sorted(disaggregator.fields)].value_counts()

    return {
        "fields": all_fields,
        "data_fields": disaggregator.fields,
        "distributions": distributions,
        "disaggregators": [module.name for module in disaggregator.modules],
        "column": column,
        "feature_names": feature_names,
    }


# Regenerate every configured dataset (in ``configs`` order). The comprehension
# is fully evaluated before ``update``, so a failure leaves cache_dict untouched.
cache_dict.update(
    {name: generate_cached_data(**config) for name, config in configs.items()}
)
joblib.dump(cache_dict, cache_file)