|
"""Program Synthesis dataset from dreamcoder. https://github.com/ellisk42/ec""" |
|
from random import choice, shuffle |
|
import datasets |
|
|
|
from dreamcoder.domains.text.makeTextTasks import makeTasks as textMakeTasks |
|
from dreamcoder.domains.list.main import main as listMakeTasks |
|
|
|
|
|
# Summary text passed to `datasets.DatasetInfo(description=...)` in `_info`.
_DESCRIPTION = """\
Generated program synthesis datasets used to train dreamcoder.
"""

# Schema of one example row: all four columns are declared as strings.
_FEATURES = datasets.Features(
    {
        "description": datasets.Value("string"),
        "input": datasets.Value("string"),
        "output": datasets.Value("string"),
        "types": datasets.Value("string")
    }
)

# Upstream project the task generators are imported from.
_HOMEPAGE = "https://github.com/ellisk42/ec"

_LICENSE = "MIT License"

# Number of (key, row) pairs yielded by `ProgramSynthesis._generate_examples`.
_MAX_STEPS = 10
|
|
|
|
|
class infIterator:
    """Endless source of shuffled example rows built from generated tasks.

    ``make_mthd`` is a zero-argument callable returning an iterable of task
    objects.  Each task must expose ``request``, ``name`` and ``examples``
    (an iterable of ``(input, output)`` pairs).  Every example becomes one
    row dict whose ``input``, ``output`` and ``types`` values are ``str()``
    renderings and whose ``description`` is the task name.
    """

    def __init__(self, make_mthd):
        self.make_mthd = make_mthd
        # Defined up front so the attribute exists before the first reset().
        self.rows = []
        # ``None`` means "nothing generated yet"; reset() switches it to 0.
        self.i = None

    def reset(self):
        """Call ``make_mthd`` again, flatten its tasks into rows and shuffle."""
        rows = []
        for task in self.make_mthd():
            base = {
                'types': str(task.request),
                "description": task.name,
            }
            rows.extend(
                dict(input=str(inp), output=str(outp), **base)
                for (inp, outp) in task.examples
            )

        shuffle(rows)
        self.rows = rows
        self.i = 0

    def step(self):
        """Return the next row, regenerating the pool only when it is needed.

        The pool is refilled lazily: ``make_mthd`` runs on the first call and
        then again only when another row is requested after the current pool
        has been fully consumed, so no generation work is done for rows that
        are never asked for.

        Raises:
            IndexError: if ``make_mthd`` yields no examples at all.
        """
        if self.i is None or self.i >= len(self.rows):
            self.reset()
            if not self.rows:
                # Same exception type as indexing the empty pool would give,
                # but with a message that points at the real cause.
                raise IndexError(
                    "make_mthd produced no examples; cannot step an empty task set"
                )

        row = self.rows[self.i]
        self.i += 1
        return row
|
|
|
|
|
class ProgramSynthesis(datasets.GeneratorBasedBuilder):
    """Program Synthesis dataset from dreamcoder.

    Three configurations are offered: ``text`` and ``list`` draw rows from a
    single task family, while ``all`` (the default) picks one of the two
    families at random for every generated example.
    """

    VERSION = datasets.Version("1.1.0")
    BUILDER_CONFIGS = [
        datasets.BuilderConfig(name="text", version=VERSION, description="Text tasks."),
        datasets.BuilderConfig(name="list", version=VERSION, description="List tasks."),
        datasets.BuilderConfig(name="all", version=VERSION, description="All tasks at once."),
    ]
    DEFAULT_CONFIG_NAME = "all"

    def _info(self):
        """Return the dataset metadata built from the module-level constants."""
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=_FEATURES,
            supervised_keys=("input", "output"),
            homepage=_HOMEPAGE,
            license=_LICENSE,
        )

    def _split_generators(self, dl_manager):
        """Declare a single TRAIN split; ``dl_manager`` is not used here."""
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
            ),
        ]

    def _generate_examples(self):
        """Yield ``_MAX_STEPS`` pairs of ``(key, row)``.

        ``key`` is the running index and ``row`` comes from the
        ``infIterator`` of the selected task family.
        """
        task_samples = {
            'text': infIterator(textMakeTasks),
            'list': infIterator(listMakeTasks)
        }
        # random.choice indexes its argument, and a dict keys view is not
        # subscriptable on Python 3, so materialise the names once up front.
        dataset_types = list(task_samples)

        for key in range(_MAX_STEPS):
            if self.config.name == 'all':
                dataset_type = choice(dataset_types)
            else:
                dataset_type = self.config.name

            yield key, task_samples[dataset_type].step()
|
|