# disaggregators / generate_datasets.py
# dawood (HF staff)
# Duplicate from society-ethics/disaggregators (924d3bd)
from datasets import load_dataset
from disaggregators import Disaggregator
from disaggregators.disaggregation_modules.age import Age, AgeLabels, AgeConfig
class MeSHAgeLabels(AgeLabels):
    """Custom age-group labels passed to the ``Age`` disaggregation module.

    Extends ``AgeLabels`` with eight groups, declared from youngest to
    oldest. The names presumably follow the MeSH age-group vocabulary --
    confirm against the MeSH reference if the exact definitions matter.
    """

    # Childhood and adolescence.
    INFANT = "infant"
    CHILD_PRESCHOOL = "child_preschool"
    CHILD = "child"
    ADOLESCENT = "adolescent"

    # Adulthood.
    ADULT = "adult"
    MIDDLE_AGED = "middle_aged"

    # Older adults; the 80-and-over group has its own label.
    AGED = "aged"
    AGED_80_OVER = "aged_80_over"
# Configuration for the age module using the custom label set.
# `ages` and `breakpoints` have the same length (eight entries each);
# presumably the breakpoint at index i is the age threshold for the label at
# index i -- TODO confirm the exact boundary semantics against AgeConfig.
_mesh_age_config = AgeConfig(
    labels=MeSHAgeLabels,
    ages=[
        MeSHAgeLabels.INFANT,
        MeSHAgeLabels.CHILD_PRESCHOOL,
        MeSHAgeLabels.CHILD,
        MeSHAgeLabels.ADOLESCENT,
        MeSHAgeLabels.ADULT,
        MeSHAgeLabels.MIDDLE_AGED,
        MeSHAgeLabels.AGED,
        MeSHAgeLabels.AGED_80_OVER,
    ],
    breakpoints=[0, 2, 5, 12, 18, 44, 64, 79],
)

# Age disaggregation module that reads the "question" column.
age = Age(config=_mesh_age_config, column="question")
# Combine the custom age module with "gender" (presumably the name of one of
# the library's built-in modules -- verify against the disaggregators docs).
# Both operate on the "question" column.
disaggregator = Disaggregator(
    [age, "gender"],
    column="question",
)

# Load the training split of MedMCQA.
ds = load_dataset(path="medmcqa", split="train")

# Apply the disaggregator to the dataset via `map`.
ds_mapped = ds.map(function=disaggregator)

# Upload the mapped dataset to the Hugging Face Hub under this repo id.
ds_mapped.push_to_hub(repo_id="society-ethics/medmcqa_age_gender_custom")