Fraser-Greenlee
commited on
Commit
•
e1c1753
1
Parent(s):
4c34db7
add dreamcoder codebase
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- dreamcoder/__init__.py +107 -0
- dreamcoder/compression.py +282 -0
- dreamcoder/deprecated/__init__.py +0 -0
- dreamcoder/deprecated/network.py +479 -0
- dreamcoder/differentiation.py +393 -0
- dreamcoder/domains/__init__.py +0 -0
- dreamcoder/domains/arithmetic/__init__.py +0 -0
- dreamcoder/domains/arithmetic/arithmeticPrimitives.py +58 -0
- dreamcoder/domains/list/__init__.py +0 -0
- dreamcoder/domains/list/listPrimitives.py +546 -0
- dreamcoder/domains/list/main.py +410 -0
- dreamcoder/domains/list/makeListTasks.py +587 -0
- dreamcoder/domains/logo/__init__.py +0 -0
- dreamcoder/domains/logo/logoPrimitives.py +41 -0
- dreamcoder/domains/logo/main.py +450 -0
- dreamcoder/domains/logo/makeLogoTasks.py +777 -0
- dreamcoder/domains/misc/RobustFillPrimitives.py +308 -0
- dreamcoder/domains/misc/__init__.py +0 -0
- dreamcoder/domains/misc/algolispPrimitives.py +508 -0
- dreamcoder/domains/misc/deepcoderPrimitives.py +352 -0
- dreamcoder/domains/misc/napsPrimitives.py +198 -0
- dreamcoder/domains/regex/__init__.py +0 -0
- dreamcoder/domains/regex/groundtruthRegexes.py +172 -0
- dreamcoder/domains/regex/main.py +384 -0
- dreamcoder/domains/regex/makeRegexTasks.py +347 -0
- dreamcoder/domains/regex/regexPrimitives.py +367 -0
- dreamcoder/domains/text/__init__.py +0 -0
- dreamcoder/domains/text/main.py +270 -0
- dreamcoder/domains/text/makeTextTasks.py +424 -0
- dreamcoder/domains/text/textPrimitives.py +87 -0
- dreamcoder/domains/tower/__init__.py +0 -0
- dreamcoder/domains/tower/main.py +359 -0
- dreamcoder/domains/tower/makeTowerTasks.py +556 -0
- dreamcoder/domains/tower/towerPrimitives.py +152 -0
- dreamcoder/domains/tower/tower_common.py +173 -0
- dreamcoder/dreamcoder.py +1074 -0
- dreamcoder/dreaming.py +90 -0
- dreamcoder/ec.py +3 -0
- dreamcoder/enumeration.py +469 -0
- dreamcoder/fragmentGrammar.py +430 -0
- dreamcoder/fragmentUtilities.py +405 -0
- dreamcoder/frontier.py +247 -0
- dreamcoder/grammar.py +1308 -0
- dreamcoder/likelihoodModel.py +407 -0
- dreamcoder/primitiveGraph.py +182 -0
- dreamcoder/program.py +1214 -0
- dreamcoder/recognition.py +1528 -0
- dreamcoder/task.py +244 -0
- dreamcoder/taskBatcher.py +200 -0
- dreamcoder/type.py +378 -0
dreamcoder/__init__.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
EC codebase Python library (AKA the "frontend")
|
3 |
+
|
4 |
+
Module mapping details:
|
5 |
+
|
6 |
+
TODO: remove module mapping code when backwards-compatibility is no longer required.
|
7 |
+
|
8 |
+
The below module mapping is required for backwards-compatibility with old pickle files
|
9 |
+
generated from before the EC codebase refactor. New files added to the codebase do not
|
10 |
+
need to be added to the mapping, but if the existing modules are moved, then this the
|
11 |
+
mapping needs to be updated to reflect the move or rename.
|
12 |
+
|
13 |
+
The mapping uses the following pattern:
|
14 |
+
|
15 |
+
sys.modules[<old module path>] = <new module reference>
|
16 |
+
|
17 |
+
This is because the previous structure of the codebase was completely flat, and when refactoring
|
18 |
+
to a hierarchical files, loading previous pickle files no longer works properly. It is important
|
19 |
+
to retain the ability to read old pickle files generated from official experiments. As a workaround,
|
20 |
+
the old module paths are included below. A preferable alternative would be to export program state
|
21 |
+
into JSON files instead of pickle files to avoid issues where the underlying classes change, so that
|
22 |
+
could be a future improvement to this project. Until then, we use the module mapping workaround.
|
23 |
+
|
24 |
+
For more info, see this StackOverflow answer: https://stackoverflow.com/a/2121918/2573242
|
25 |
+
"""
|
26 |
+
import sys
|
27 |
+
|
28 |
+
from dreamcoder import differentiation
|
29 |
+
from dreamcoder import dreamcoder
|
30 |
+
from dreamcoder import enumeration
|
31 |
+
from dreamcoder import fragmentGrammar
|
32 |
+
from dreamcoder import fragmentUtilities
|
33 |
+
from dreamcoder import frontier
|
34 |
+
from dreamcoder import grammar
|
35 |
+
from dreamcoder import likelihoodModel
|
36 |
+
from dreamcoder import program
|
37 |
+
from dreamcoder import primitiveGraph
|
38 |
+
try:
|
39 |
+
from dreamcoder import recognition
|
40 |
+
except:
|
41 |
+
print("Failure loading recognition - only acceptable if using pypy ",file=sys.stderr)
|
42 |
+
from dreamcoder import task
|
43 |
+
from dreamcoder import taskBatcher
|
44 |
+
from dreamcoder import type
|
45 |
+
from dreamcoder import utilities
|
46 |
+
from dreamcoder import vs
|
47 |
+
from dreamcoder.domains.misc import algolispPrimitives, deepcoderPrimitives
|
48 |
+
from dreamcoder.domains.misc import RobustFillPrimitives
|
49 |
+
from dreamcoder.domains.misc import napsPrimitives
|
50 |
+
from dreamcoder.domains.tower import makeTowerTasks
|
51 |
+
from dreamcoder.domains.tower import towerPrimitives
|
52 |
+
from dreamcoder.domains.tower import tower_common
|
53 |
+
from dreamcoder.domains.tower import main as tower_main
|
54 |
+
from dreamcoder.domains.regex import groundtruthRegexes
|
55 |
+
from dreamcoder.domains.regex import regexPrimitives
|
56 |
+
from dreamcoder.domains.regex import makeRegexTasks
|
57 |
+
#from dreamcoder.domains.regex import main as regex_main
|
58 |
+
from dreamcoder.domains.logo import logoPrimitives
|
59 |
+
from dreamcoder.domains.logo import makeLogoTasks
|
60 |
+
from dreamcoder.domains.logo import main as logo_main
|
61 |
+
from dreamcoder.domains.list import listPrimitives
|
62 |
+
from dreamcoder.domains.list import makeListTasks
|
63 |
+
from dreamcoder.domains.list import main as list_main
|
64 |
+
from dreamcoder.domains.arithmetic import arithmeticPrimitives
|
65 |
+
from dreamcoder.domains.text import textPrimitives
|
66 |
+
from dreamcoder.domains.text import makeTextTasks
|
67 |
+
from dreamcoder.domains.text import main as text_main
|
68 |
+
|
69 |
+
sys.modules['differentiation'] = differentiation
|
70 |
+
sys.modules['ec'] = dreamcoder
|
71 |
+
sys.modules['enumeration'] = enumeration
|
72 |
+
sys.modules['fragmentGrammar'] = fragmentGrammar
|
73 |
+
sys.modules['fragmentUtilities'] = fragmentUtilities
|
74 |
+
sys.modules['frontier'] = frontier
|
75 |
+
sys.modules['grammar'] = grammar
|
76 |
+
sys.modules['likelihoodModel'] = likelihoodModel
|
77 |
+
sys.modules['program'] = program
|
78 |
+
try: sys.modules['recognition'] = recognition
|
79 |
+
except: pass
|
80 |
+
sys.modules['task'] = task
|
81 |
+
sys.modules['taskBatcher'] = taskBatcher
|
82 |
+
sys.modules['type'] = type
|
83 |
+
sys.modules['utilities'] = utilities
|
84 |
+
sys.modules['vs'] = vs
|
85 |
+
sys.modules['algolispPrimitives'] = algolispPrimitives
|
86 |
+
sys.modules['RobustFillPrimitives'] = RobustFillPrimitives
|
87 |
+
sys.modules['napsPrimitives'] = napsPrimitives
|
88 |
+
sys.modules['makeTowerTasks'] = makeTowerTasks
|
89 |
+
sys.modules['towerPrimitives'] = towerPrimitives
|
90 |
+
sys.modules['tower_common'] = tower_common
|
91 |
+
#sys.modules['tower'] = tower_main
|
92 |
+
sys.modules['groundtruthRegexes'] = groundtruthRegexes
|
93 |
+
sys.modules['regexPrimitives'] = regexPrimitives
|
94 |
+
sys.modules['makeRegexTasks'] = makeRegexTasks
|
95 |
+
#sys.modules['regexes'] = regex_main
|
96 |
+
sys.modules['deepcoderPrimitives'] = deepcoderPrimitives
|
97 |
+
sys.modules['logoPrimitives'] = logoPrimitives
|
98 |
+
sys.modules['makeLogoTasks'] = makeLogoTasks
|
99 |
+
#sys.modules['logo'] = logo_main
|
100 |
+
sys.modules['listPrimitives'] = listPrimitives
|
101 |
+
sys.modules['makeListTasks'] = makeListTasks
|
102 |
+
#sys.modules['list'] = list_main
|
103 |
+
sys.modules['arithmeticPrimitives'] = arithmeticPrimitives
|
104 |
+
sys.modules['textPrimitives'] = textPrimitives
|
105 |
+
sys.modules['makeTextTasks'] = makeTextTasks
|
106 |
+
#sys.modules['text'] = text_main
|
107 |
+
sys.modules['primitiveGraph'] = primitiveGraph
|
dreamcoder/compression.py
ADDED
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import pickle
|
5 |
+
import subprocess
|
6 |
+
import sys
|
7 |
+
|
8 |
+
from dreamcoder.fragmentGrammar import FragmentGrammar
|
9 |
+
from dreamcoder.frontier import Frontier, FrontierEntry
|
10 |
+
from dreamcoder.grammar import Grammar
|
11 |
+
from dreamcoder.task import Task
|
12 |
+
from dreamcoder.program import Program, Invented
|
13 |
+
from dreamcoder.utilities import eprint, timing, callCompiled, get_root_dir
|
14 |
+
from dreamcoder.vs import induceGrammar_Beta
|
15 |
+
|
16 |
+
|
17 |
+
def induceGrammar(*args, **kwargs):
|
18 |
+
if sum(not f.empty for f in args[1]) == 0:
|
19 |
+
eprint("No nonempty frontiers, exiting grammar induction early.")
|
20 |
+
return args[0], args[1]
|
21 |
+
backend = kwargs.pop("backend", "pypy")
|
22 |
+
if 'pypy' in backend:
|
23 |
+
# pypy might not like some of the imports needed for the primitives
|
24 |
+
# but the primitive values are irrelevant for compression
|
25 |
+
# therefore strip them out and then replace them once we are done
|
26 |
+
# ditto for task data
|
27 |
+
g0,frontiers = args[0].strip_primitive_values(), \
|
28 |
+
[front.strip_primitive_values() for front in args[1]]
|
29 |
+
original_tasks = {f.task.name: f.task for f in frontiers}
|
30 |
+
frontiers = [Frontier(f.entries, Task(f.task.name,f.task.request,[]))
|
31 |
+
for f in frontiers ]
|
32 |
+
args = [g0,frontiers]
|
33 |
+
|
34 |
+
|
35 |
+
with timing("Induced a grammar"):
|
36 |
+
if backend == "pypy":
|
37 |
+
g, newFrontiers = callCompiled(pypyInduce, *args, **kwargs)
|
38 |
+
elif backend == "rust":
|
39 |
+
g, newFrontiers = rustInduce(*args, **kwargs)
|
40 |
+
elif backend == "vs":
|
41 |
+
g, newFrontiers = rustInduce(*args, vs=True, **kwargs)
|
42 |
+
elif backend == "pypy_vs":
|
43 |
+
kwargs.pop('iteration')
|
44 |
+
kwargs.pop('topk_use_only_likelihood')
|
45 |
+
fn = '/tmp/vs.pickle'
|
46 |
+
with open(fn, 'wb') as handle:
|
47 |
+
pickle.dump((args, kwargs), handle)
|
48 |
+
eprint("For debugging purposes, the version space compression invocation has been saved to", fn)
|
49 |
+
g, newFrontiers = callCompiled(induceGrammar_Beta, *args, **kwargs)
|
50 |
+
elif backend == "ocaml":
|
51 |
+
kwargs.pop('iteration')
|
52 |
+
kwargs.pop('topk_use_only_likelihood')
|
53 |
+
kwargs['topI'] = 300
|
54 |
+
kwargs['bs'] = 1000000
|
55 |
+
g, newFrontiers = ocamlInduce(*args, **kwargs)
|
56 |
+
elif backend == "memorize":
|
57 |
+
g, newFrontiers = memorizeInduce(*args, **kwargs)
|
58 |
+
else:
|
59 |
+
assert False, "unknown compressor"
|
60 |
+
|
61 |
+
if 'pypy' in backend:
|
62 |
+
g, newFrontiers = g.unstrip_primitive_values(), \
|
63 |
+
[front.unstrip_primitive_values() for front in newFrontiers]
|
64 |
+
newFrontiers = [Frontier(f.entries, original_tasks[f.task.name])
|
65 |
+
for f in newFrontiers]
|
66 |
+
|
67 |
+
|
68 |
+
return g, newFrontiers
|
69 |
+
|
70 |
+
def memorizeInduce(g, frontiers, **kwargs):
|
71 |
+
existingInventions = {p.uncurry()
|
72 |
+
for p in g.primitives }
|
73 |
+
programs = {f.bestPosterior.program for f in frontiers if not f.empty}
|
74 |
+
newInventions = programs - existingInventions
|
75 |
+
newGrammar = Grammar.uniform([p for p in g.primitives] + \
|
76 |
+
[Invented(ni) for ni in newInventions])
|
77 |
+
|
78 |
+
# rewrite in terms of new primitives
|
79 |
+
def substitute(p):
|
80 |
+
nonlocal newInventions
|
81 |
+
if p in newInventions: return Invented(p).uncurry()
|
82 |
+
return p
|
83 |
+
newFrontiers = [Frontier([FrontierEntry(program=np,
|
84 |
+
logPrior=newGrammar.logLikelihood(f.task.request, np),
|
85 |
+
logLikelihood=e.logLikelihood)
|
86 |
+
for e in f
|
87 |
+
for np in [substitute(e.program)] ],
|
88 |
+
task=f.task)
|
89 |
+
for f in frontiers ]
|
90 |
+
return newGrammar, newFrontiers
|
91 |
+
|
92 |
+
|
93 |
+
|
94 |
+
|
95 |
+
|
96 |
+
def pypyInduce(*args, **kwargs):
|
97 |
+
kwargs.pop('iteration')
|
98 |
+
return FragmentGrammar.induceFromFrontiers(*args, **kwargs)
|
99 |
+
|
100 |
+
|
101 |
+
def ocamlInduce(g, frontiers, _=None,
|
102 |
+
topK=1, pseudoCounts=1.0, aic=1.0,
|
103 |
+
structurePenalty=0.001, a=0, CPUs=1,
|
104 |
+
bs=1000000, topI=300):
|
105 |
+
# This is a dirty hack!
|
106 |
+
# Memory consumption increases with the number of CPUs
|
107 |
+
# And early on we have a lot of stuff to compress
|
108 |
+
# If this is the first iteration, only use a fraction of the available CPUs
|
109 |
+
if all(not p.isInvented for p in g.primitives):
|
110 |
+
if a > 3:
|
111 |
+
CPUs = max(1, int(CPUs / 6))
|
112 |
+
else:
|
113 |
+
CPUs = max(1, int(CPUs / 3))
|
114 |
+
else:
|
115 |
+
CPUs = max(1, int(CPUs / 2))
|
116 |
+
CPUs = 2
|
117 |
+
|
118 |
+
# X X X FIXME X X X
|
119 |
+
# for unknown reasons doing compression all in one go works correctly and doing it with Python and the outer loop causes problems
|
120 |
+
iterations = 99 # maximum number of components to add at once
|
121 |
+
|
122 |
+
while True:
|
123 |
+
g0 = g
|
124 |
+
|
125 |
+
originalFrontiers = frontiers
|
126 |
+
t2f = {f.task: f for f in frontiers}
|
127 |
+
frontiers = [f for f in frontiers if not f.empty]
|
128 |
+
message = {"arity": a,
|
129 |
+
"topK": topK,
|
130 |
+
"pseudoCounts": float(pseudoCounts),
|
131 |
+
"aic": aic,
|
132 |
+
"bs": bs,
|
133 |
+
"topI": topI,
|
134 |
+
"structurePenalty": float(structurePenalty),
|
135 |
+
"CPUs": CPUs,
|
136 |
+
"DSL": g.json(),
|
137 |
+
"iterations": iterations,
|
138 |
+
"frontiers": [f.json()
|
139 |
+
for f in frontiers]}
|
140 |
+
|
141 |
+
message = json.dumps(message)
|
142 |
+
if True:
|
143 |
+
timestamp = datetime.datetime.now().isoformat()
|
144 |
+
os.system("mkdir -p compressionMessages")
|
145 |
+
fn = "compressionMessages/%s" % timestamp
|
146 |
+
with open(fn, "w") as f:
|
147 |
+
f.write(message)
|
148 |
+
eprint("Compression message saved to:", fn)
|
149 |
+
|
150 |
+
try:
|
151 |
+
# Get relative path
|
152 |
+
compressor_file = os.path.join(get_root_dir(), 'compression')
|
153 |
+
process = subprocess.Popen(compressor_file,
|
154 |
+
stdin=subprocess.PIPE,
|
155 |
+
stdout=subprocess.PIPE)
|
156 |
+
response, error = process.communicate(bytes(message, encoding="utf-8"))
|
157 |
+
response = json.loads(response.decode("utf-8"))
|
158 |
+
except OSError as exc:
|
159 |
+
raise exc
|
160 |
+
|
161 |
+
g = response["DSL"]
|
162 |
+
g = Grammar(g["logVariable"],
|
163 |
+
[(l, p.infer(), p)
|
164 |
+
for production in g["productions"]
|
165 |
+
for l in [production["logProbability"]]
|
166 |
+
for p in [Program.parse(production["expression"])]],
|
167 |
+
continuationType=g0.continuationType)
|
168 |
+
|
169 |
+
frontiers = {original.task:
|
170 |
+
Frontier([FrontierEntry(p,
|
171 |
+
logLikelihood=e["logLikelihood"],
|
172 |
+
logPrior=g.logLikelihood(original.task.request, p))
|
173 |
+
for e in new["programs"]
|
174 |
+
for p in [Program.parse(e["program"])]],
|
175 |
+
task=original.task)
|
176 |
+
for original, new in zip(frontiers, response["frontiers"])}
|
177 |
+
frontiers = [frontiers.get(f.task, t2f[f.task])
|
178 |
+
for f in originalFrontiers]
|
179 |
+
if iterations == 1 and len(g) > len(g0):
|
180 |
+
eprint("Grammar changed - running another round of consolidation.")
|
181 |
+
continue
|
182 |
+
else:
|
183 |
+
eprint("Finished consolidation.")
|
184 |
+
return g, frontiers
|
185 |
+
|
186 |
+
|
187 |
+
def rustInduce(g0, frontiers, _=None,
|
188 |
+
topK=1, pseudoCounts=1.0, aic=1.0,
|
189 |
+
structurePenalty=0.001, a=0, CPUs=1, iteration=-1,
|
190 |
+
topk_use_only_likelihood=False,
|
191 |
+
vs=False):
|
192 |
+
def finite_logp(l):
|
193 |
+
return l if l != float("-inf") else -1000
|
194 |
+
|
195 |
+
message = {
|
196 |
+
"strategy": {"version-spaces": {"top_i": 50}}
|
197 |
+
if vs else
|
198 |
+
{"fragment-grammars": {}},
|
199 |
+
"params": {
|
200 |
+
"structure_penalty": structurePenalty,
|
201 |
+
"pseudocounts": int(pseudoCounts + 0.5),
|
202 |
+
"topk": topK,
|
203 |
+
"topk_use_only_likelihood": topk_use_only_likelihood,
|
204 |
+
"aic": aic if aic != float("inf") else None,
|
205 |
+
"arity": a,
|
206 |
+
},
|
207 |
+
"primitives": [{"name": p.name, "tp": str(t), "logp": finite_logp(l)}
|
208 |
+
for l, t, p in g0.productions if p.isPrimitive],
|
209 |
+
"inventions": [{"expression": str(p.body),
|
210 |
+
"logp": finite_logp(l)} # -inf=-100
|
211 |
+
for l, t, p in g0.productions if p.isInvented],
|
212 |
+
"variable_logprob": finite_logp(g0.logVariable),
|
213 |
+
"frontiers": [{
|
214 |
+
"task_tp": str(f.task.request),
|
215 |
+
"solutions": [{
|
216 |
+
"expression": str(e.program),
|
217 |
+
"logprior": finite_logp(e.logPrior),
|
218 |
+
"loglikelihood": e.logLikelihood,
|
219 |
+
} for e in f],
|
220 |
+
} for f in frontiers],
|
221 |
+
}
|
222 |
+
|
223 |
+
eprint("running rust compressor")
|
224 |
+
|
225 |
+
messageJson = json.dumps(message)
|
226 |
+
|
227 |
+
with open("jsonDebug", "w") as f:
|
228 |
+
f.write(messageJson)
|
229 |
+
|
230 |
+
# check which version of python we are using
|
231 |
+
# if >=3.6 do:
|
232 |
+
if sys.version_info[1] >= 6:
|
233 |
+
p = subprocess.Popen(
|
234 |
+
['./rust_compressor/rust_compressor'],
|
235 |
+
encoding='utf-8',
|
236 |
+
stdin=subprocess.PIPE,
|
237 |
+
stdout=subprocess.PIPE)
|
238 |
+
elif sys.version_info[1] == 5:
|
239 |
+
p = subprocess.Popen(
|
240 |
+
['./rust_compressor/rust_compressor'],
|
241 |
+
stdin=subprocess.PIPE,
|
242 |
+
stdout=subprocess.PIPE)
|
243 |
+
|
244 |
+
messageJson = bytearray(messageJson, encoding='utf-8')
|
245 |
+
# convert messageJson string to bytes
|
246 |
+
else:
|
247 |
+
eprint("must be python 3.5 or 3.6")
|
248 |
+
assert False
|
249 |
+
|
250 |
+
p.stdin.write(messageJson)
|
251 |
+
p.stdin.flush()
|
252 |
+
p.stdin.close()
|
253 |
+
|
254 |
+
if p.returncode is not None:
|
255 |
+
raise ValueError("rust compressor failed")
|
256 |
+
|
257 |
+
if sys.version_info[1] >= 6:
|
258 |
+
resp = json.load(p.stdout)
|
259 |
+
elif sys.version_info[1] == 5:
|
260 |
+
import codecs
|
261 |
+
resp = json.load(codecs.getreader('utf-8')(p.stdout))
|
262 |
+
|
263 |
+
productions = [(x["logp"], p) for p, x in
|
264 |
+
zip((p for (_, _, p) in g0.productions if p.isPrimitive), resp["primitives"])] + \
|
265 |
+
[(i["logp"], Invented(Program.parse(i["expression"])))
|
266 |
+
for i in resp["inventions"]]
|
267 |
+
productions = [(l if l is not None else float("-inf"), p)
|
268 |
+
for l, p in productions]
|
269 |
+
g = Grammar.fromProductions(productions, resp["variable_logprob"], continuationType=g0.continuationType)
|
270 |
+
newFrontiers = [
|
271 |
+
Frontier(
|
272 |
+
[
|
273 |
+
FrontierEntry(
|
274 |
+
Program.parse(
|
275 |
+
s["expression"]),
|
276 |
+
logPrior=s["logprior"],
|
277 |
+
logLikelihood=s["loglikelihood"]) for s in r["solutions"]],
|
278 |
+
f.task) for f,
|
279 |
+
r in zip(
|
280 |
+
frontiers,
|
281 |
+
resp["frontiers"])]
|
282 |
+
return g, newFrontiers
|
dreamcoder/deprecated/__init__.py
ADDED
File without changes
|
dreamcoder/deprecated/network.py
ADDED
@@ -0,0 +1,479 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Deprecated network.py module. This file only exists to support backwards-compatibility
|
3 |
+
with old pickle files. See lib/__init__.py for more information.
|
4 |
+
"""
|
5 |
+
|
6 |
+
from __future__ import print_function
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from torch.autograd import Variable
|
12 |
+
from torch.nn.parameter import Parameter
|
13 |
+
|
14 |
+
|
15 |
+
# UPGRADING TO INPUT -> OUTPUT -> TARGET
|
16 |
+
# Todo:
|
17 |
+
# [X] Output attending to input
|
18 |
+
# [X] Target attending to output
|
19 |
+
# [ ] check passing hidden state between encoders/decoder (+ pass c?)
|
20 |
+
# [ ] add v_output
|
21 |
+
|
22 |
+
|
23 |
+
def choose(matrix, idxs):
|
24 |
+
if isinstance(idxs, Variable):
|
25 |
+
idxs = idxs.data
|
26 |
+
assert(matrix.ndimension() == 2)
|
27 |
+
unrolled_idxs = idxs + \
|
28 |
+
torch.arange(0, matrix.size(0)).type_as(idxs) * matrix.size(1)
|
29 |
+
return matrix.view(matrix.nelement())[unrolled_idxs]
|
30 |
+
|
31 |
+
|
32 |
+
class Network(nn.Module):
|
33 |
+
"""
|
34 |
+
Todo:
|
35 |
+
- Beam search
|
36 |
+
- check if this is right? attend during P->FC rather than during softmax->P?
|
37 |
+
- allow length 0 inputs/targets
|
38 |
+
- give n_examples as input to FC
|
39 |
+
- Initialise new weights randomly, rather than as zeroes
|
40 |
+
"""
|
41 |
+
|
42 |
+
def __init__(
|
43 |
+
self,
|
44 |
+
input_vocabulary,
|
45 |
+
target_vocabulary,
|
46 |
+
hidden_size=512,
|
47 |
+
embedding_size=128,
|
48 |
+
cell_type="LSTM"):
|
49 |
+
"""
|
50 |
+
:param list input_vocabulary: list of possible inputs
|
51 |
+
:param list target_vocabulary: list of possible targets
|
52 |
+
"""
|
53 |
+
super(Network, self).__init__()
|
54 |
+
self.h_input_encoder_size = hidden_size
|
55 |
+
self.h_output_encoder_size = hidden_size
|
56 |
+
self.h_decoder_size = hidden_size
|
57 |
+
self.embedding_size = embedding_size
|
58 |
+
self.input_vocabulary = input_vocabulary
|
59 |
+
self.target_vocabulary = target_vocabulary
|
60 |
+
# Number of tokens in input vocabulary
|
61 |
+
self.v_input = len(input_vocabulary)
|
62 |
+
# Number of tokens in target vocabulary
|
63 |
+
self.v_target = len(target_vocabulary)
|
64 |
+
|
65 |
+
self.cell_type = cell_type
|
66 |
+
if cell_type == 'GRU':
|
67 |
+
self.input_encoder_cell = nn.GRUCell(
|
68 |
+
input_size=self.v_input + 1,
|
69 |
+
hidden_size=self.h_input_encoder_size,
|
70 |
+
bias=True)
|
71 |
+
self.input_encoder_init = Parameter(
|
72 |
+
torch.rand(1, self.h_input_encoder_size))
|
73 |
+
self.output_encoder_cell = nn.GRUCell(
|
74 |
+
input_size=self.v_input +
|
75 |
+
1 +
|
76 |
+
self.h_input_encoder_size,
|
77 |
+
hidden_size=self.h_output_encoder_size,
|
78 |
+
bias=True)
|
79 |
+
self.decoder_cell = nn.GRUCell(
|
80 |
+
input_size=self.v_target + 1,
|
81 |
+
hidden_size=self.h_decoder_size,
|
82 |
+
bias=True)
|
83 |
+
if cell_type == 'LSTM':
|
84 |
+
self.input_encoder_cell = nn.LSTMCell(
|
85 |
+
input_size=self.v_input + 1,
|
86 |
+
hidden_size=self.h_input_encoder_size,
|
87 |
+
bias=True)
|
88 |
+
self.input_encoder_init = nn.ParameterList([Parameter(torch.rand(
|
89 |
+
1, self.h_input_encoder_size)), Parameter(torch.rand(1, self.h_input_encoder_size))])
|
90 |
+
self.output_encoder_cell = nn.LSTMCell(
|
91 |
+
input_size=self.v_input +
|
92 |
+
1 +
|
93 |
+
self.h_input_encoder_size,
|
94 |
+
hidden_size=self.h_output_encoder_size,
|
95 |
+
bias=True)
|
96 |
+
self.output_encoder_init_c = Parameter(
|
97 |
+
torch.rand(1, self.h_output_encoder_size))
|
98 |
+
self.decoder_cell = nn.LSTMCell(
|
99 |
+
input_size=self.v_target + 1,
|
100 |
+
hidden_size=self.h_decoder_size,
|
101 |
+
bias=True)
|
102 |
+
self.decoder_init_c = Parameter(torch.rand(1, self.h_decoder_size))
|
103 |
+
|
104 |
+
self.W = nn.Linear(
|
105 |
+
self.h_output_encoder_size +
|
106 |
+
self.h_decoder_size,
|
107 |
+
self.embedding_size)
|
108 |
+
self.V = nn.Linear(self.embedding_size, self.v_target + 1)
|
109 |
+
self.input_A = nn.Bilinear(
|
110 |
+
self.h_input_encoder_size,
|
111 |
+
self.h_output_encoder_size,
|
112 |
+
1,
|
113 |
+
bias=False)
|
114 |
+
self.output_A = nn.Bilinear(
|
115 |
+
self.h_output_encoder_size,
|
116 |
+
self.h_decoder_size,
|
117 |
+
1,
|
118 |
+
bias=False)
|
119 |
+
self.input_EOS = torch.zeros(1, self.v_input + 1)
|
120 |
+
self.input_EOS[:, -1] = 1
|
121 |
+
self.input_EOS = Parameter(self.input_EOS)
|
122 |
+
self.output_EOS = torch.zeros(1, self.v_input + 1)
|
123 |
+
self.output_EOS[:, -1] = 1
|
124 |
+
self.output_EOS = Parameter(self.output_EOS)
|
125 |
+
self.target_EOS = torch.zeros(1, self.v_target + 1)
|
126 |
+
self.target_EOS[:, -1] = 1
|
127 |
+
self.target_EOS = Parameter(self.target_EOS)
|
128 |
+
|
129 |
+
def __getstate__(self):
|
130 |
+
if hasattr(self, 'opt'):
|
131 |
+
return dict([(k, v) for k, v in self.__dict__.items(
|
132 |
+
) if k is not 'opt'] + [('optstate', self.opt.state_dict())])
|
133 |
+
# return {**{k:v for k,v in self.__dict__.items() if k is not 'opt'},
|
134 |
+
# 'optstate': self.opt.state_dict()}
|
135 |
+
else:
|
136 |
+
return self.__dict__
|
137 |
+
|
138 |
+
def __setstate__(self, state):
|
139 |
+
self.__dict__.update(state)
|
140 |
+
# Legacy:
|
141 |
+
if isinstance(self.input_encoder_init, tuple):
|
142 |
+
self.input_encoder_init = nn.ParameterList(
|
143 |
+
list(self.input_encoder_init))
|
144 |
+
|
145 |
+
def clear_optimiser(self):
|
146 |
+
if hasattr(self, 'opt'):
|
147 |
+
del self.opt
|
148 |
+
if hasattr(self, 'optstate'):
|
149 |
+
del self.optstate
|
150 |
+
|
151 |
+
def get_optimiser(self):
|
152 |
+
self.opt = torch.optim.Adam(self.parameters(), lr=0.001)
|
153 |
+
if hasattr(self, 'optstate'):
|
154 |
+
self.opt.load_state_dict(self.optstate)
|
155 |
+
|
156 |
+
def optimiser_step(self, inputs, outputs, target):
|
157 |
+
if not hasattr(self, 'opt'):
|
158 |
+
self.get_optimiser()
|
159 |
+
score = self.score(inputs, outputs, target, autograd=True).mean()
|
160 |
+
(-score).backward()
|
161 |
+
self.opt.step()
|
162 |
+
self.opt.zero_grad()
|
163 |
+
return score.data[0]
|
164 |
+
|
165 |
+
def set_target_vocabulary(self, target_vocabulary):
|
166 |
+
if target_vocabulary == self.target_vocabulary:
|
167 |
+
return
|
168 |
+
|
169 |
+
V_weight = []
|
170 |
+
V_bias = []
|
171 |
+
decoder_ih = []
|
172 |
+
|
173 |
+
for i in range(len(target_vocabulary)):
|
174 |
+
if target_vocabulary[i] in self.target_vocabulary:
|
175 |
+
j = self.target_vocabulary.index(target_vocabulary[i])
|
176 |
+
V_weight.append(self.V.weight.data[j:j + 1])
|
177 |
+
V_bias.append(self.V.bias.data[j:j + 1])
|
178 |
+
decoder_ih.append(self.decoder_cell.weight_ih.data[:, j:j + 1])
|
179 |
+
else:
|
180 |
+
V_weight.append(torch.zeros(1, self.V.weight.size(1)))
|
181 |
+
V_bias.append(torch.ones(1) * -10)
|
182 |
+
decoder_ih.append(
|
183 |
+
torch.zeros(
|
184 |
+
self.decoder_cell.weight_ih.data.size(0), 1))
|
185 |
+
|
186 |
+
V_weight.append(self.V.weight.data[-1:])
|
187 |
+
V_bias.append(self.V.bias.data[-1:])
|
188 |
+
decoder_ih.append(self.decoder_cell.weight_ih.data[:, -1:])
|
189 |
+
|
190 |
+
self.target_vocabulary = target_vocabulary
|
191 |
+
self.v_target = len(target_vocabulary)
|
192 |
+
self.target_EOS.data = torch.zeros(1, self.v_target + 1)
|
193 |
+
self.target_EOS.data[:, -1] = 1
|
194 |
+
|
195 |
+
self.V.weight.data = torch.cat(V_weight, dim=0)
|
196 |
+
self.V.bias.data = torch.cat(V_bias, dim=0)
|
197 |
+
self.V.out_features = self.V.bias.data.size(0)
|
198 |
+
|
199 |
+
self.decoder_cell.weight_ih.data = torch.cat(decoder_ih, dim=1)
|
200 |
+
self.decoder_cell.input_size = self.decoder_cell.weight_ih.data.size(1)
|
201 |
+
|
202 |
+
self.clear_optimiser()
|
203 |
+
|
204 |
+
def input_encoder_get_init(self, batch_size):
|
205 |
+
if self.cell_type == "GRU":
|
206 |
+
return self.input_encoder_init.repeat(batch_size, 1)
|
207 |
+
if self.cell_type == "LSTM":
|
208 |
+
return tuple(x.repeat(batch_size, 1)
|
209 |
+
for x in self.input_encoder_init)
|
210 |
+
|
211 |
+
def output_encoder_get_init(self, input_encoder_h):
|
212 |
+
if self.cell_type == "GRU":
|
213 |
+
return input_encoder_h
|
214 |
+
if self.cell_type == "LSTM":
|
215 |
+
return (
|
216 |
+
input_encoder_h,
|
217 |
+
self.output_encoder_init_c.repeat(
|
218 |
+
input_encoder_h.size(0),
|
219 |
+
1))
|
220 |
+
|
221 |
+
def decoder_get_init(self, output_encoder_h):
|
222 |
+
if self.cell_type == "GRU":
|
223 |
+
return output_encoder_h
|
224 |
+
if self.cell_type == "LSTM":
|
225 |
+
return (
|
226 |
+
output_encoder_h,
|
227 |
+
self.decoder_init_c.repeat(
|
228 |
+
output_encoder_h.size(0),
|
229 |
+
1))
|
230 |
+
|
231 |
+
def cell_get_h(self, cell_state):
|
232 |
+
if self.cell_type == "GRU":
|
233 |
+
return cell_state
|
234 |
+
if self.cell_type == "LSTM":
|
235 |
+
return cell_state[0]
|
236 |
+
|
237 |
+
def score(self, inputs, outputs, target, autograd=False):
|
238 |
+
inputs = self.inputsToTensors(inputs)
|
239 |
+
outputs = self.inputsToTensors(outputs)
|
240 |
+
target = self.targetToTensor(target)
|
241 |
+
target, score = self.run(inputs, outputs, target=target, mode="score")
|
242 |
+
# target = self.tensorToOutput(target)
|
243 |
+
if autograd:
|
244 |
+
return score
|
245 |
+
else:
|
246 |
+
return score.data
|
247 |
+
|
248 |
+
def sample(self, inputs, outputs):
|
249 |
+
inputs = self.inputsToTensors(inputs)
|
250 |
+
outputs = self.inputsToTensors(outputs)
|
251 |
+
target, score = self.run(inputs, outputs, mode="sample")
|
252 |
+
target = self.tensorToOutput(target)
|
253 |
+
return target
|
254 |
+
|
255 |
+
def sampleAndScore(self, inputs, outputs, nRepeats=None):
|
256 |
+
inputs = self.inputsToTensors(inputs)
|
257 |
+
outputs = self.inputsToTensors(outputs)
|
258 |
+
if nRepeats is None:
|
259 |
+
target, score = self.run(inputs, outputs, mode="sample")
|
260 |
+
target = self.tensorToOutput(target)
|
261 |
+
return target, score.data
|
262 |
+
else:
|
263 |
+
target = []
|
264 |
+
score = []
|
265 |
+
for i in range(nRepeats):
|
266 |
+
# print("repeat %d" % i)
|
267 |
+
t, s = self.run(inputs, outputs, mode="sample")
|
268 |
+
t = self.tensorToOutput(t)
|
269 |
+
target.extend(t)
|
270 |
+
score.extend(list(s.data))
|
271 |
+
return target, score
|
272 |
+
|
273 |
+
def run(self, inputs, outputs, target=None, mode="sample"):
|
274 |
+
"""
|
275 |
+
:param mode: "score" returns log p(target|input), "sample" returns target ~ p(-|input)
|
276 |
+
:param List[LongTensor] inputs: n_examples * (max_length_input * batch_size)
|
277 |
+
:param List[LongTensor] target: max_length_target * batch_size
|
278 |
+
"""
|
279 |
+
assert((mode == "score" and target is not None) or mode == "sample")
|
280 |
+
|
281 |
+
n_examples = len(inputs)
|
282 |
+
max_length_input = [inputs[j].size(0) for j in range(n_examples)]
|
283 |
+
max_length_output = [outputs[j].size(0) for j in range(n_examples)]
|
284 |
+
max_length_target = target.size(0) if target is not None else 10
|
285 |
+
batch_size = inputs[0].size(1)
|
286 |
+
|
287 |
+
score = Variable(torch.zeros(batch_size))
|
288 |
+
inputs_scatter = [Variable(torch.zeros(max_length_input[j], batch_size, self.v_input + 1).scatter_(
|
289 |
+
2, inputs[j][:, :, None], 1)) for j in range(n_examples)] # n_examples * (max_length_input * batch_size * v_input+1)
|
290 |
+
outputs_scatter = [Variable(torch.zeros(max_length_output[j], batch_size, self.v_input + 1).scatter_(
|
291 |
+
2, outputs[j][:, :, None], 1)) for j in range(n_examples)] # n_examples * (max_length_output * batch_size * v_input+1)
|
292 |
+
if target is not None:
|
293 |
+
target_scatter = Variable(torch.zeros(max_length_target,
|
294 |
+
batch_size,
|
295 |
+
self.v_target + 1).scatter_(2,
|
296 |
+
target[:,
|
297 |
+
:,
|
298 |
+
None],
|
299 |
+
1)) # max_length_target * batch_size * v_target+1
|
300 |
+
|
301 |
+
# -------------- Input Encoder -------------
|
302 |
+
|
303 |
+
# n_examples * (max_length_input * batch_size * h_encoder_size)
|
304 |
+
input_H = []
|
305 |
+
input_embeddings = [] # h for example at INPUT_EOS
|
306 |
+
# 0 until (and including) INPUT_EOS, then -inf
|
307 |
+
input_attention_mask = []
|
308 |
+
for j in range(n_examples):
|
309 |
+
active = torch.Tensor(max_length_input[j], batch_size).byte()
|
310 |
+
active[0, :] = 1
|
311 |
+
state = self.input_encoder_get_init(batch_size)
|
312 |
+
hs = []
|
313 |
+
for i in range(max_length_input[j]):
|
314 |
+
state = self.input_encoder_cell(
|
315 |
+
inputs_scatter[j][i, :, :], state)
|
316 |
+
if i + 1 < max_length_input[j]:
|
317 |
+
active[i + 1, :] = active[i, :] * \
|
318 |
+
(inputs[j][i, :] != self.v_input)
|
319 |
+
h = self.cell_get_h(state)
|
320 |
+
hs.append(h[None, :, :])
|
321 |
+
input_H.append(torch.cat(hs, 0))
|
322 |
+
embedding_idx = active.sum(0).long() - 1
|
323 |
+
embedding = input_H[j].gather(0, Variable(
|
324 |
+
embedding_idx[None, :, None].repeat(1, 1, self.h_input_encoder_size)))[0]
|
325 |
+
input_embeddings.append(embedding)
|
326 |
+
input_attention_mask.append(Variable(active.float().log()))
|
327 |
+
|
328 |
+
# -------------- Output Encoder -------------
|
329 |
+
|
330 |
+
def input_attend(j, h_out):
|
331 |
+
"""
|
332 |
+
'general' attention from https://arxiv.org/pdf/1508.04025.pdf
|
333 |
+
:param j: Index of example
|
334 |
+
:param h_out: batch_size * h_output_encoder_size
|
335 |
+
"""
|
336 |
+
scores = self.input_A(
|
337 |
+
input_H[j].view(
|
338 |
+
max_length_input[j] * batch_size,
|
339 |
+
self.h_input_encoder_size),
|
340 |
+
h_out.view(
|
341 |
+
batch_size,
|
342 |
+
self.h_output_encoder_size).repeat(
|
343 |
+
max_length_input[j],
|
344 |
+
1)).view(
|
345 |
+
max_length_input[j],
|
346 |
+
batch_size) + input_attention_mask[j]
|
347 |
+
c = (F.softmax(scores[:, :, None], dim=0) * input_H[j]).sum(0)
|
348 |
+
return c
|
349 |
+
|
350 |
+
# n_examples * (max_length_input * batch_size * h_encoder_size)
|
351 |
+
output_H = []
|
352 |
+
output_embeddings = [] # h for example at INPUT_EOS
|
353 |
+
# 0 until (and including) INPUT_EOS, then -inf
|
354 |
+
output_attention_mask = []
|
355 |
+
for j in range(n_examples):
|
356 |
+
active = torch.Tensor(max_length_output[j], batch_size).byte()
|
357 |
+
active[0, :] = 1
|
358 |
+
state = self.output_encoder_get_init(input_embeddings[j])
|
359 |
+
hs = []
|
360 |
+
h = self.cell_get_h(state)
|
361 |
+
for i in range(max_length_output[j]):
|
362 |
+
state = self.output_encoder_cell(torch.cat(
|
363 |
+
[outputs_scatter[j][i, :, :], input_attend(j, h)], 1), state)
|
364 |
+
if i + 1 < max_length_output[j]:
|
365 |
+
active[i + 1, :] = active[i, :] * \
|
366 |
+
(outputs[j][i, :] != self.v_input)
|
367 |
+
h = self.cell_get_h(state)
|
368 |
+
hs.append(h[None, :, :])
|
369 |
+
output_H.append(torch.cat(hs, 0))
|
370 |
+
embedding_idx = active.sum(0).long() - 1
|
371 |
+
embedding = output_H[j].gather(0, Variable(
|
372 |
+
embedding_idx[None, :, None].repeat(1, 1, self.h_output_encoder_size)))[0]
|
373 |
+
output_embeddings.append(embedding)
|
374 |
+
output_attention_mask.append(Variable(active.float().log()))
|
375 |
+
|
376 |
+
# ------------------ Decoder -----------------
|
377 |
+
|
378 |
+
def output_attend(j, h_dec):
|
379 |
+
"""
|
380 |
+
'general' attention from https://arxiv.org/pdf/1508.04025.pdf
|
381 |
+
:param j: Index of example
|
382 |
+
:param h_dec: batch_size * h_decoder_size
|
383 |
+
"""
|
384 |
+
scores = self.output_A(
|
385 |
+
output_H[j].view(
|
386 |
+
max_length_output[j] * batch_size,
|
387 |
+
self.h_output_encoder_size),
|
388 |
+
h_dec.view(
|
389 |
+
batch_size,
|
390 |
+
self.h_decoder_size).repeat(
|
391 |
+
max_length_output[j],
|
392 |
+
1)).view(
|
393 |
+
max_length_output[j],
|
394 |
+
batch_size) + output_attention_mask[j]
|
395 |
+
c = (F.softmax(scores[:, :, None], dim=0) * output_H[j]).sum(0)
|
396 |
+
return c
|
397 |
+
|
398 |
+
# Multi-example pooling: Figure 3, https://arxiv.org/pdf/1703.07469.pdf
|
399 |
+
target = target if mode == "score" else torch.zeros(
|
400 |
+
max_length_target, batch_size).long()
|
401 |
+
decoder_states = [
|
402 |
+
self.decoder_get_init(
|
403 |
+
output_embeddings[j]) for j in range(n_examples)] # P
|
404 |
+
active = torch.ones(batch_size).byte()
|
405 |
+
for i in range(max_length_target):
|
406 |
+
FC = []
|
407 |
+
for j in range(n_examples):
|
408 |
+
h = self.cell_get_h(decoder_states[j])
|
409 |
+
p_aug = torch.cat([h, output_attend(j, h)], 1)
|
410 |
+
FC.append(F.tanh(self.W(p_aug)[None, :, :]))
|
411 |
+
# batch_size * embedding_size
|
412 |
+
m = torch.max(torch.cat(FC, 0), 0)[0]
|
413 |
+
logsoftmax = F.log_softmax(self.V(m), dim=1)
|
414 |
+
if mode == "sample":
|
415 |
+
target[i, :] = torch.multinomial(
|
416 |
+
logsoftmax.data.exp(), 1)[:, 0]
|
417 |
+
score = score + \
|
418 |
+
choose(logsoftmax, target[i, :]) * Variable(active.float())
|
419 |
+
active *= (target[i, :] != self.v_target)
|
420 |
+
for j in range(n_examples):
|
421 |
+
if mode == "score":
|
422 |
+
target_char_scatter = target_scatter[i, :, :]
|
423 |
+
elif mode == "sample":
|
424 |
+
target_char_scatter = Variable(torch.zeros(
|
425 |
+
batch_size, self.v_target + 1).scatter_(1, target[i, :, None], 1))
|
426 |
+
decoder_states[j] = self.decoder_cell(
|
427 |
+
target_char_scatter, decoder_states[j])
|
428 |
+
return target, score
|
429 |
+
|
430 |
+
def inputsToTensors(self, inputss):
|
431 |
+
"""
|
432 |
+
:param inputss: size = nBatch * nExamples
|
433 |
+
"""
|
434 |
+
tensors = []
|
435 |
+
for j in range(len(inputss[0])):
|
436 |
+
inputs = [x[j] for x in inputss]
|
437 |
+
maxlen = max(len(s) for s in inputs)
|
438 |
+
t = torch.ones(
|
439 |
+
1 if maxlen == 0 else maxlen + 1,
|
440 |
+
len(inputs)).long() * self.v_input
|
441 |
+
for i in range(len(inputs)):
|
442 |
+
s = inputs[i]
|
443 |
+
if len(s) > 0:
|
444 |
+
t[:len(s), i] = torch.LongTensor(
|
445 |
+
[self.input_vocabulary.index(x) for x in s])
|
446 |
+
tensors.append(t)
|
447 |
+
return tensors
|
448 |
+
|
449 |
+
def targetToTensor(self, targets):
|
450 |
+
"""
|
451 |
+
:param targets:
|
452 |
+
"""
|
453 |
+
maxlen = max(len(s) for s in targets)
|
454 |
+
t = torch.ones(
|
455 |
+
1 if maxlen == 0 else maxlen + 1,
|
456 |
+
len(targets)).long() * self.v_target
|
457 |
+
for i in range(len(targets)):
|
458 |
+
s = targets[i]
|
459 |
+
if len(s) > 0:
|
460 |
+
t[:len(s), i] = torch.LongTensor(
|
461 |
+
[self.target_vocabulary.index(x) for x in s])
|
462 |
+
return t
|
463 |
+
|
464 |
+
def tensorToOutput(self, tensor):
|
465 |
+
"""
|
466 |
+
:param tensor: max_length * batch_size
|
467 |
+
"""
|
468 |
+
out = []
|
469 |
+
for i in range(tensor.size(1)):
|
470 |
+
l = tensor[:, i].tolist()
|
471 |
+
if l[0] == self.v_target:
|
472 |
+
out.append([])
|
473 |
+
elif self.v_target in l:
|
474 |
+
final = tensor[:, i].tolist().index(self.v_target)
|
475 |
+
out.append([self.target_vocabulary[x]
|
476 |
+
for x in tensor[:final, i]])
|
477 |
+
else:
|
478 |
+
out.append([self.target_vocabulary[x] for x in tensor[:, i]])
|
479 |
+
return out
|
dreamcoder/differentiation.py
ADDED
@@ -0,0 +1,393 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import random
|
3 |
+
from dreamcoder.utilities import *
|
4 |
+
|
5 |
+
|
6 |
+
class InvalidLoss(Exception):
|
7 |
+
pass
|
8 |
+
|
9 |
+
|
10 |
+
class DN(object):
|
11 |
+
'''differentiable node: parent object of every differentiable operation'''
|
12 |
+
|
13 |
+
def __init__(self, arguments):
|
14 |
+
self.gradient = None
|
15 |
+
if arguments != []:
|
16 |
+
self.data = None
|
17 |
+
self.arguments = arguments
|
18 |
+
|
19 |
+
# descendents: every variable that takes this variable as input
|
20 |
+
# descendents: [(DN,float)]
|
21 |
+
# the additional float parameter is d Descendent / d This
|
22 |
+
self.descendents = []
|
23 |
+
|
24 |
+
self.recalculate()
|
25 |
+
|
26 |
+
def __str__(self):
|
27 |
+
if self.arguments == []:
|
28 |
+
return self.name
|
29 |
+
return "(%s %s)" % (self.name, " ".join(str(x)
|
30 |
+
for x in self.arguments))
|
31 |
+
|
32 |
+
def __repr__(self):
|
33 |
+
return "DN(op = %s, data = %s, grad = %s, #descendents = %d, args = %s)" % (
|
34 |
+
self.name, self.data, self.gradient, len(self.descendents), self.arguments)
|
35 |
+
|
36 |
+
@property
|
37 |
+
def derivative(self): return self.differentiate()
|
38 |
+
|
39 |
+
def differentiate(self):
|
40 |
+
if self.gradient is None:
|
41 |
+
self.gradient = sum(partial * descendent.differentiate()
|
42 |
+
for descendent, partial in self.descendents)
|
43 |
+
return self.gradient
|
44 |
+
|
45 |
+
def zeroEverything(self):
|
46 |
+
if self.gradient is None and self.descendents == [] and (
|
47 |
+
self.data is None or self.arguments == []):
|
48 |
+
return
|
49 |
+
|
50 |
+
self.gradient = None
|
51 |
+
self.descendents = []
|
52 |
+
if self.arguments != []:
|
53 |
+
self.data = None
|
54 |
+
|
55 |
+
for x in self.arguments:
|
56 |
+
x.zeroEverything()
|
57 |
+
|
58 |
+
def lightweightRecalculate(self):
|
59 |
+
return self.forward(*[a.lightweightRecalculate()
|
60 |
+
for a in self.arguments])
|
61 |
+
|
62 |
+
def recalculate(self):
|
63 |
+
if self.data is None:
|
64 |
+
inputs = [a.recalculate() for a in self.arguments]
|
65 |
+
self.data = self.forward(*inputs)
|
66 |
+
# if invalid(self.data):
|
67 |
+
# eprint("I am invalid",repr(self))
|
68 |
+
# eprint("Here are my inputs",inputs)
|
69 |
+
# self.zeroEverything()
|
70 |
+
# eprint("Here I am after being zeroed",repr(self))
|
71 |
+
# raise Exception('invalid loss')
|
72 |
+
#assert valid(self.data)
|
73 |
+
partials = self.backward(*inputs)
|
74 |
+
for d, a in zip(partials, self.arguments):
|
75 |
+
# if invalid(d):
|
76 |
+
# eprint("I have an invalid derivative",self)
|
77 |
+
# eprint("Inputs",inputs)
|
78 |
+
# eprint("partials",partials)
|
79 |
+
# raise Exception('invalid derivative')
|
80 |
+
a.descendents.append((self, d))
|
81 |
+
return self.data
|
82 |
+
|
83 |
+
def backPropagation(self):
|
84 |
+
self.gradient = 1.
|
85 |
+
self.recursivelyDifferentiate()
|
86 |
+
|
87 |
+
def recursivelyDifferentiate(self):
|
88 |
+
self.differentiate()
|
89 |
+
for x in self.arguments:
|
90 |
+
x.recursivelyDifferentiate()
|
91 |
+
|
92 |
+
def updateNetwork(self):
|
93 |
+
self.zeroEverything()
|
94 |
+
l = self.recalculate()
|
95 |
+
self.backPropagation()
|
96 |
+
return l
|
97 |
+
|
98 |
+
def log(self): return Logarithm(self)
|
99 |
+
|
100 |
+
def square(self): return Square(self)
|
101 |
+
|
102 |
+
def exp(self): return Exponentiation(self)
|
103 |
+
|
104 |
+
def clamp(self, l, u): return Clamp(self, l, u)
|
105 |
+
|
106 |
+
def __abs__(self): return AbsoluteValue(self)
|
107 |
+
|
108 |
+
def __add__(self, o): return Addition(self, Placeholder.maybe(o))
|
109 |
+
|
110 |
+
def __radd__(self, o): return Addition(self, Placeholder.maybe(o))
|
111 |
+
|
112 |
+
def __sub__(self, o): return Subtraction(self, Placeholder.maybe(o))
|
113 |
+
|
114 |
+
def __rsub__(self, o): return Subtraction(Placeholder.maybe(o), self)
|
115 |
+
|
116 |
+
def __mul__(self, o): return Multiplication(self, Placeholder.maybe(o))
|
117 |
+
|
118 |
+
def __rmul__(self, o): return Multiplication(self, Placeholder.maybe(o))
|
119 |
+
|
120 |
+
def __neg__(self): return Negation(self)
|
121 |
+
|
122 |
+
def __truediv__(self, o): return Division(self, Placeholder.maybe(o))
|
123 |
+
|
124 |
+
def __rtruediv__(self, o): return Division(Placeholder.maybe(o), self)
|
125 |
+
|
126 |
+
def numericallyVerifyGradients(self, parameters):
|
127 |
+
calculatedGradients = [p.derivative for p in parameters]
|
128 |
+
e = 0.00001
|
129 |
+
for j, p in enumerate(parameters):
|
130 |
+
p.data -= e
|
131 |
+
y1 = self.lightweightRecalculate()
|
132 |
+
p.data += 2 * e
|
133 |
+
y2 = self.lightweightRecalculate()
|
134 |
+
p.data -= e
|
135 |
+
d = (y2 - y1) / (2 * e)
|
136 |
+
if abs(calculatedGradients[j] - d) > 0.1:
|
137 |
+
eprint(
|
138 |
+
"Bad gradient: expected %f, got %f" %
|
139 |
+
(d, calculatedGradients[j]))
|
140 |
+
|
141 |
+
def gradientDescent(
|
142 |
+
self,
|
143 |
+
parameters,
|
144 |
+
_=None,
|
145 |
+
lr=0.001,
|
146 |
+
steps=10**3,
|
147 |
+
update=None):
|
148 |
+
for j in range(steps):
|
149 |
+
l = self.updateNetwork()
|
150 |
+
if update is not None and j % update == 0:
|
151 |
+
eprint("LOSS:", l)
|
152 |
+
for p in parameters:
|
153 |
+
eprint(p.data, '\t', p.derivative)
|
154 |
+
if invalid(l):
|
155 |
+
raise InvalidLoss()
|
156 |
+
|
157 |
+
for p in parameters:
|
158 |
+
p.data -= lr * p.derivative
|
159 |
+
return self.data
|
160 |
+
|
161 |
+
def restartingOptimize(self, parameters, _=None, attempts=1,
|
162 |
+
s=1., decay=0.5, grow=0.1,
|
163 |
+
lr=0.1, steps=10**3, update=None):
|
164 |
+
ls = []
|
165 |
+
for _ in range(attempts):
|
166 |
+
for p in parameters:
|
167 |
+
p.data = random.random()*10 - 5
|
168 |
+
ls.append(
|
169 |
+
self.resilientBackPropagation(
|
170 |
+
parameters, lr=lr, steps=steps,
|
171 |
+
decay=decay, grow=grow))
|
172 |
+
return min(ls)
|
173 |
+
|
174 |
+
def resilientBackPropagation(
|
175 |
+
self,
|
176 |
+
parameters,
|
177 |
+
_=None,
|
178 |
+
decay=0.5,
|
179 |
+
grow=1.2,
|
180 |
+
lr=0.1,
|
181 |
+
steps=10**3,
|
182 |
+
update=None):
|
183 |
+
previousSign = [None] * len(parameters)
|
184 |
+
lr = [lr] * len(parameters)
|
185 |
+
for j in range(steps):
|
186 |
+
l = self.updateNetwork()
|
187 |
+
|
188 |
+
if update is not None and j % update == 0:
|
189 |
+
eprint("LOSS:", l)
|
190 |
+
eprint("\t".join(str(p.derivative) for p in parameters))
|
191 |
+
if invalid(l):
|
192 |
+
raise InvalidLoss()
|
193 |
+
|
194 |
+
newSigns = [p.derivative > 0 for p in parameters]
|
195 |
+
for i, p in enumerate(parameters):
|
196 |
+
if p.derivative > 0:
|
197 |
+
p.data -= lr[i]
|
198 |
+
elif p.derivative < 0:
|
199 |
+
p.data += lr[i]
|
200 |
+
if previousSign[i] is not None:
|
201 |
+
if previousSign[i] == newSigns[i]:
|
202 |
+
lr[i] *= grow
|
203 |
+
else:
|
204 |
+
lr[i] *= decay
|
205 |
+
previousSign = newSigns
|
206 |
+
|
207 |
+
return self.data
|
208 |
+
|
209 |
+
|
210 |
+
class Placeholder(DN):
|
211 |
+
COUNTER = 0
|
212 |
+
|
213 |
+
def __init__(self, initialValue=0., name=None):
|
214 |
+
self.data = initialValue
|
215 |
+
super(Placeholder, self).__init__([])
|
216 |
+
if name is None:
|
217 |
+
name = "p_" + str(Placeholder.COUNTER)
|
218 |
+
Placeholder.COUNTER += 1
|
219 |
+
self.name = name
|
220 |
+
|
221 |
+
@staticmethod
|
222 |
+
def named(namePrefix, initialValue=0.):
|
223 |
+
p = Placeholder(initialValue, namePrefix + str(Placeholder.COUNTER))
|
224 |
+
Placeholder.COUNTER += 1
|
225 |
+
return p
|
226 |
+
|
227 |
+
def __str__(self):
|
228 |
+
return "Placeholder(%s = %s)" % (self.name, self.data)
|
229 |
+
|
230 |
+
@staticmethod
|
231 |
+
def maybe(x):
|
232 |
+
if isinstance(x, DN):
|
233 |
+
return x
|
234 |
+
return Placeholder(float(x))
|
235 |
+
|
236 |
+
def forward(self): return self.data
|
237 |
+
|
238 |
+
def backward(self): return []
|
239 |
+
|
240 |
+
|
241 |
+
class Clamp(DN):
|
242 |
+
def __init__(self, x, l, u):
|
243 |
+
assert u > l
|
244 |
+
self.l = l
|
245 |
+
self.u = u
|
246 |
+
super(Clamp, self).__init__([x])
|
247 |
+
self.name = "clamp"
|
248 |
+
|
249 |
+
def forward(self, x):
|
250 |
+
if x > self.u:
|
251 |
+
return self.u
|
252 |
+
if x < self.l:
|
253 |
+
return self.l
|
254 |
+
return x
|
255 |
+
|
256 |
+
def backward(self, x):
|
257 |
+
if x > self.u or x < self.l:
|
258 |
+
return [0.]
|
259 |
+
else:
|
260 |
+
return [1.]
|
261 |
+
|
262 |
+
|
263 |
+
class Addition(DN):
|
264 |
+
def __init__(self, x, y):
|
265 |
+
super(Addition, self).__init__([x, y])
|
266 |
+
self.name = '+'
|
267 |
+
|
268 |
+
def forward(self, x, y): return x + y
|
269 |
+
|
270 |
+
def backward(self, x, y): return [1., 1.]
|
271 |
+
|
272 |
+
|
273 |
+
class Subtraction(DN):
|
274 |
+
def __init__(self, x, y):
|
275 |
+
super(Subtraction, self).__init__([x, y])
|
276 |
+
self.name = '-'
|
277 |
+
|
278 |
+
def forward(self, x, y): return x - y
|
279 |
+
|
280 |
+
def backward(self, x, y): return [1., -1.]
|
281 |
+
|
282 |
+
|
283 |
+
class Negation(DN):
|
284 |
+
def __init__(self, x):
|
285 |
+
super(Negation, self).__init__([x])
|
286 |
+
self.name = '-'
|
287 |
+
|
288 |
+
def forward(self, x): return -x
|
289 |
+
|
290 |
+
def backward(self, x): return [-1.]
|
291 |
+
|
292 |
+
|
293 |
+
class AbsoluteValue(DN):
|
294 |
+
def __init__(self, x):
|
295 |
+
super(AbsoluteValue, self).__init__([x])
|
296 |
+
self.name = 'abs'
|
297 |
+
|
298 |
+
def forward(self, x): return abs(x)
|
299 |
+
|
300 |
+
def backward(self, x):
|
301 |
+
if x > 0:
|
302 |
+
return [1.]
|
303 |
+
return [-1.]
|
304 |
+
|
305 |
+
|
306 |
+
class Multiplication(DN):
|
307 |
+
def __init__(self, x, y):
|
308 |
+
super(Multiplication, self).__init__([x, y])
|
309 |
+
self.name = '*'
|
310 |
+
|
311 |
+
def forward(self, x, y): return x * y
|
312 |
+
|
313 |
+
def backward(self, x, y): return [y, x]
|
314 |
+
|
315 |
+
|
316 |
+
class Division(DN):
|
317 |
+
def __init__(self, x, y):
|
318 |
+
super(Division, self).__init__([x, y])
|
319 |
+
self.name = '/'
|
320 |
+
|
321 |
+
def forward(self, x, y): return x / y
|
322 |
+
|
323 |
+
def backward(self, x, y): return [1.0 / y, -x / (y * y)]
|
324 |
+
|
325 |
+
|
326 |
+
class Square(DN):
|
327 |
+
def __init__(self, x):
|
328 |
+
super(Square, self).__init__([x])
|
329 |
+
self.name = 'sq'
|
330 |
+
|
331 |
+
def forward(self, x): return x * x
|
332 |
+
|
333 |
+
def backward(self, x): return [2 * x]
|
334 |
+
|
335 |
+
|
336 |
+
class Exponentiation(DN):
|
337 |
+
def __init__(self, x):
|
338 |
+
super(Exponentiation, self).__init__([x])
|
339 |
+
self.name = 'exp'
|
340 |
+
|
341 |
+
def forward(self, x): return math.exp(x)
|
342 |
+
|
343 |
+
def backward(self, x): return [math.exp(x)]
|
344 |
+
|
345 |
+
|
346 |
+
class Logarithm(DN):
|
347 |
+
def __init__(self, x):
|
348 |
+
super(Logarithm, self).__init__([x])
|
349 |
+
self.name = 'log'
|
350 |
+
|
351 |
+
def forward(self, x): return math.log(x)
|
352 |
+
|
353 |
+
def backward(self, x): return [1. / x]
|
354 |
+
|
355 |
+
|
356 |
+
class LSE(DN):
|
357 |
+
def __init__(self, xs):
|
358 |
+
super(LSE, self).__init__(xs)
|
359 |
+
self.name = 'LSE'
|
360 |
+
|
361 |
+
def forward(self, *xs):
|
362 |
+
m = max(xs)
|
363 |
+
return m + math.log(sum(math.exp(y - m) for y in xs))
|
364 |
+
|
365 |
+
def backward(self, *xs):
|
366 |
+
m = max(xs)
|
367 |
+
zm = sum(math.exp(x - m) for x in xs)
|
368 |
+
return [math.exp(x - m) / zm for x in xs]
|
369 |
+
|
370 |
+
|
371 |
+
if __name__ == "__main__":
|
372 |
+
x = Placeholder(10., "x")
|
373 |
+
y = Placeholder(2., "y")
|
374 |
+
z = x - LSE([x, y])
|
375 |
+
z.updateNetwork()
|
376 |
+
eprint("dL/dx = %f\tdL/dy = %f" % (x.derivative, y.derivative))
|
377 |
+
|
378 |
+
x.data = 2.
|
379 |
+
y.data = 10.
|
380 |
+
z.updateNetwork()
|
381 |
+
eprint("dL/dx = %f\tdL/dy = %f" % (x.differentiate(), y.differentiate()))
|
382 |
+
|
383 |
+
x.data = 2.
|
384 |
+
y.data = 2.
|
385 |
+
z.updateNetwork()
|
386 |
+
eprint("z = ", z.data, z)
|
387 |
+
eprint("dL/dx = %f\tdL/dy = %f" % (x.differentiate(), y.differentiate()))
|
388 |
+
|
389 |
+
loss = -z
|
390 |
+
eprint(loss)
|
391 |
+
|
392 |
+
lr = 0.001
|
393 |
+
loss.gradientDescent([x, y], steps=10000, update=1000)
|
dreamcoder/domains/__init__.py
ADDED
File without changes
|
dreamcoder/domains/arithmetic/__init__.py
ADDED
File without changes
|
dreamcoder/domains/arithmetic/arithmeticPrimitives.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dreamcoder.program import *
|
2 |
+
from dreamcoder.type import *
|
3 |
+
|
4 |
+
|
5 |
+
def _addition(x): return lambda y: x + y
|
6 |
+
|
7 |
+
|
8 |
+
def _subtraction(x): return lambda y: x - y
|
9 |
+
|
10 |
+
|
11 |
+
def _division(x): return lambda y: x / y
|
12 |
+
|
13 |
+
|
14 |
+
subtraction = Primitive("-",
|
15 |
+
arrow(tint, arrow(tint, tint)),
|
16 |
+
_subtraction)
|
17 |
+
real_subtraction = Primitive("-.",
|
18 |
+
arrow(treal, treal, treal),
|
19 |
+
_subtraction)
|
20 |
+
addition = Primitive("+",
|
21 |
+
arrow(tint, arrow(tint, tint)),
|
22 |
+
Curried(_addition))
|
23 |
+
real_addition = Primitive("+.",
|
24 |
+
arrow(treal, treal, treal),
|
25 |
+
_addition)
|
26 |
+
|
27 |
+
|
28 |
+
def _multiplication(x): return lambda y: x * y
|
29 |
+
|
30 |
+
|
31 |
+
multiplication = Primitive("*",
|
32 |
+
arrow(tint, arrow(tint, tint)),
|
33 |
+
_multiplication)
|
34 |
+
real_multiplication = Primitive("*.",
|
35 |
+
arrow(treal, treal, treal),
|
36 |
+
_multiplication)
|
37 |
+
real_division = Primitive("/.",
|
38 |
+
arrow(treal, treal, treal),
|
39 |
+
_division)
|
40 |
+
|
41 |
+
|
42 |
+
def _power(a): return lambda b: a**b
|
43 |
+
|
44 |
+
|
45 |
+
real_power = Primitive("power",
|
46 |
+
arrow(treal, treal, treal),
|
47 |
+
_power)
|
48 |
+
|
49 |
+
k1 = Primitive("1", tint, 1)
|
50 |
+
k_negative1 = Primitive("negative_1", tint, -1)
|
51 |
+
k0 = Primitive("0", tint, 0)
|
52 |
+
for n in range(2,10):
|
53 |
+
Primitive(str(n),tint,n)
|
54 |
+
|
55 |
+
f1 = Primitive("1.", treal, 1.)
|
56 |
+
f0 = Primitive("0.", treal, 0)
|
57 |
+
real = Primitive("REAL", treal, None)
|
58 |
+
fpi = Primitive("pi", treal, 3.14)
|
dreamcoder/domains/list/__init__.py
ADDED
File without changes
|
dreamcoder/domains/list/listPrimitives.py
ADDED
@@ -0,0 +1,546 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dreamcoder.program import Primitive, Program
|
2 |
+
from dreamcoder.grammar import Grammar
|
3 |
+
from dreamcoder.type import tlist, tint, tbool, arrow, t0, t1, t2
|
4 |
+
|
5 |
+
import math
|
6 |
+
from functools import reduce
|
7 |
+
|
8 |
+
|
9 |
+
def _flatten(l): return [x for xs in l for x in xs]
|
10 |
+
|
11 |
+
def _range(n):
|
12 |
+
if n < 100: return list(range(n))
|
13 |
+
raise ValueError()
|
14 |
+
def _if(c): return lambda t: lambda f: t if c else f
|
15 |
+
|
16 |
+
|
17 |
+
def _and(x): return lambda y: x and y
|
18 |
+
|
19 |
+
|
20 |
+
def _or(x): return lambda y: x or y
|
21 |
+
|
22 |
+
|
23 |
+
def _addition(x): return lambda y: x + y
|
24 |
+
|
25 |
+
|
26 |
+
def _subtraction(x): return lambda y: x - y
|
27 |
+
|
28 |
+
|
29 |
+
def _multiplication(x): return lambda y: x * y
|
30 |
+
|
31 |
+
|
32 |
+
def _negate(x): return -x
|
33 |
+
|
34 |
+
|
35 |
+
def _reverse(x): return list(reversed(x))
|
36 |
+
|
37 |
+
|
38 |
+
def _append(x): return lambda y: x + y
|
39 |
+
|
40 |
+
|
41 |
+
def _cons(x): return lambda y: [x] + y
|
42 |
+
|
43 |
+
|
44 |
+
def _car(x): return x[0]
|
45 |
+
|
46 |
+
|
47 |
+
def _cdr(x): return x[1:]
|
48 |
+
|
49 |
+
|
50 |
+
def _isEmpty(x): return x == []
|
51 |
+
|
52 |
+
|
53 |
+
def _single(x): return [x]
|
54 |
+
|
55 |
+
|
56 |
+
def _slice(x): return lambda y: lambda l: l[x:y]
|
57 |
+
|
58 |
+
|
59 |
+
def _map(f): return lambda l: list(map(f, l))
|
60 |
+
|
61 |
+
|
62 |
+
def _zip(a): return lambda b: lambda f: list(map(lambda x,y: f(x)(y), a, b))
|
63 |
+
|
64 |
+
|
65 |
+
def _mapi(f): return lambda l: list(map(lambda i_x: f(i_x[0])(i_x[1]), enumerate(l)))
|
66 |
+
|
67 |
+
|
68 |
+
def _reduce(f): return lambda x0: lambda l: reduce(lambda a, x: f(a)(x), l, x0)
|
69 |
+
|
70 |
+
|
71 |
+
def _reducei(f): return lambda x0: lambda l: reduce(
|
72 |
+
lambda a, t: f(t[0])(a)(t[1]), enumerate(l), x0)
|
73 |
+
|
74 |
+
|
75 |
+
def _fold(l): return lambda x0: lambda f: reduce(
|
76 |
+
lambda a, x: f(x)(a), l[::-1], x0)
|
77 |
+
|
78 |
+
|
79 |
+
def _eq(x): return lambda y: x == y
|
80 |
+
|
81 |
+
|
82 |
+
def _eq0(x): return x == 0
|
83 |
+
|
84 |
+
|
85 |
+
def _a1(x): return x + 1
|
86 |
+
|
87 |
+
|
88 |
+
def _d1(x): return x - 1
|
89 |
+
|
90 |
+
|
91 |
+
def _mod(x): return lambda y: x % y
|
92 |
+
|
93 |
+
|
94 |
+
def _not(x): return not x
|
95 |
+
|
96 |
+
|
97 |
+
def _gt(x): return lambda y: x > y
|
98 |
+
|
99 |
+
|
100 |
+
def _index(j): return lambda l: l[j]
|
101 |
+
|
102 |
+
|
103 |
+
def _replace(f): return lambda lnew: lambda lin: _flatten(
|
104 |
+
lnew if f(i)(x) else [x] for i, x in enumerate(lin))
|
105 |
+
|
106 |
+
|
107 |
+
def _isPrime(n):
|
108 |
+
return n in {
|
109 |
+
2,
|
110 |
+
3,
|
111 |
+
5,
|
112 |
+
7,
|
113 |
+
11,
|
114 |
+
13,
|
115 |
+
17,
|
116 |
+
19,
|
117 |
+
23,
|
118 |
+
29,
|
119 |
+
31,
|
120 |
+
37,
|
121 |
+
41,
|
122 |
+
43,
|
123 |
+
47,
|
124 |
+
53,
|
125 |
+
59,
|
126 |
+
61,
|
127 |
+
67,
|
128 |
+
71,
|
129 |
+
73,
|
130 |
+
79,
|
131 |
+
83,
|
132 |
+
89,
|
133 |
+
97,
|
134 |
+
101,
|
135 |
+
103,
|
136 |
+
107,
|
137 |
+
109,
|
138 |
+
113,
|
139 |
+
127,
|
140 |
+
131,
|
141 |
+
137,
|
142 |
+
139,
|
143 |
+
149,
|
144 |
+
151,
|
145 |
+
157,
|
146 |
+
163,
|
147 |
+
167,
|
148 |
+
173,
|
149 |
+
179,
|
150 |
+
181,
|
151 |
+
191,
|
152 |
+
193,
|
153 |
+
197,
|
154 |
+
199}
|
155 |
+
|
156 |
+
|
157 |
+
def _isSquare(n):
|
158 |
+
return int(math.sqrt(n)) ** 2 == n
|
159 |
+
|
160 |
+
|
161 |
+
def _appendmap(f): lambda xs: [y for x in xs for y in f(x)]
|
162 |
+
|
163 |
+
|
164 |
+
def _filter(f): return lambda l: list(filter(f, l))
|
165 |
+
|
166 |
+
|
167 |
+
def _any(f): return lambda l: any(f(x) for x in l)
|
168 |
+
|
169 |
+
|
170 |
+
def _all(f): return lambda l: all(f(x) for x in l)
|
171 |
+
|
172 |
+
|
173 |
+
def _find(x):
|
174 |
+
def _inner(l):
|
175 |
+
try:
|
176 |
+
return l.index(x)
|
177 |
+
except ValueError:
|
178 |
+
return -1
|
179 |
+
return _inner
|
180 |
+
|
181 |
+
|
182 |
+
def _unfold(x): return lambda p: lambda h: lambda n: __unfold(p, f, n, x)
|
183 |
+
|
184 |
+
|
185 |
+
def __unfold(p, f, n, x, recursion_limit=50):
|
186 |
+
if recursion_limit <= 0:
|
187 |
+
raise RecursionDepthExceeded()
|
188 |
+
if p(x):
|
189 |
+
return []
|
190 |
+
return [f(x)] + __unfold(p, f, n, n(x), recursion_limit - 1)
|
191 |
+
|
192 |
+
|
193 |
+
class RecursionDepthExceeded(Exception):
|
194 |
+
pass
|
195 |
+
|
196 |
+
|
197 |
+
def _fix(argument):
|
198 |
+
def inner(body):
|
199 |
+
recursion_limit = [20]
|
200 |
+
|
201 |
+
def fix(x):
|
202 |
+
def r(z):
|
203 |
+
recursion_limit[0] -= 1
|
204 |
+
if recursion_limit[0] <= 0:
|
205 |
+
raise RecursionDepthExceeded()
|
206 |
+
else:
|
207 |
+
return fix(z)
|
208 |
+
|
209 |
+
return body(r)(x)
|
210 |
+
return fix(argument)
|
211 |
+
|
212 |
+
return inner
|
213 |
+
|
214 |
+
|
215 |
+
def curry(f): return lambda x: lambda y: f((x, y))
|
216 |
+
|
217 |
+
|
218 |
+
def _fix2(a1):
|
219 |
+
return lambda a2: lambda body: \
|
220 |
+
_fix((a1, a2))(lambda r: lambda n_l: body(curry(r))(n_l[0])(n_l[1]))
|
221 |
+
|
222 |
+
|
223 |
+
primitiveRecursion1 = Primitive("fix1",
|
224 |
+
arrow(t0,
|
225 |
+
arrow(arrow(t0, t1), t0, t1),
|
226 |
+
t1),
|
227 |
+
_fix)
|
228 |
+
|
229 |
+
primitiveRecursion2 = Primitive("fix2",
|
230 |
+
arrow(t0, t1,
|
231 |
+
arrow(arrow(t0, t1, t2), t0, t1, t2),
|
232 |
+
t2),
|
233 |
+
_fix2)
|
234 |
+
|
235 |
+
|
236 |
+
def _match(l):
|
237 |
+
return lambda b: lambda f: b if l == [] else f(l[0])(l[1:])
|
238 |
+
|
239 |
+
|
240 |
+
def primitives():
|
241 |
+
return [Primitive(str(j), tint, j) for j in range(6)] + [
|
242 |
+
Primitive("empty", tlist(t0), []),
|
243 |
+
Primitive("singleton", arrow(t0, tlist(t0)), _single),
|
244 |
+
Primitive("range", arrow(tint, tlist(tint)), _range),
|
245 |
+
Primitive("++", arrow(tlist(t0), tlist(t0), tlist(t0)), _append),
|
246 |
+
# Primitive("map", arrow(arrow(t0, t1), tlist(t0), tlist(t1)), _map),
|
247 |
+
Primitive(
|
248 |
+
"mapi",
|
249 |
+
arrow(
|
250 |
+
arrow(
|
251 |
+
tint,
|
252 |
+
t0,
|
253 |
+
t1),
|
254 |
+
tlist(t0),
|
255 |
+
tlist(t1)),
|
256 |
+
_mapi),
|
257 |
+
# Primitive("reduce", arrow(arrow(t1, t0, t1), t1, tlist(t0), t1), _reduce),
|
258 |
+
Primitive(
|
259 |
+
"reducei",
|
260 |
+
arrow(
|
261 |
+
arrow(
|
262 |
+
tint,
|
263 |
+
t1,
|
264 |
+
t0,
|
265 |
+
t1),
|
266 |
+
t1,
|
267 |
+
tlist(t0),
|
268 |
+
t1),
|
269 |
+
_reducei),
|
270 |
+
|
271 |
+
Primitive("true", tbool, True),
|
272 |
+
Primitive("not", arrow(tbool, tbool), _not),
|
273 |
+
Primitive("and", arrow(tbool, tbool, tbool), _and),
|
274 |
+
Primitive("or", arrow(tbool, tbool, tbool), _or),
|
275 |
+
# Primitive("if", arrow(tbool, t0, t0, t0), _if),
|
276 |
+
|
277 |
+
Primitive("sort", arrow(tlist(tint), tlist(tint)), sorted),
|
278 |
+
Primitive("+", arrow(tint, tint, tint), _addition),
|
279 |
+
Primitive("*", arrow(tint, tint, tint), _multiplication),
|
280 |
+
Primitive("negate", arrow(tint, tint), _negate),
|
281 |
+
Primitive("mod", arrow(tint, tint, tint), _mod),
|
282 |
+
Primitive("eq?", arrow(tint, tint, tbool), _eq),
|
283 |
+
Primitive("gt?", arrow(tint, tint, tbool), _gt),
|
284 |
+
Primitive("is-prime", arrow(tint, tbool), _isPrime),
|
285 |
+
Primitive("is-square", arrow(tint, tbool), _isSquare),
|
286 |
+
|
287 |
+
# these are achievable with above primitives, but unlikely
|
288 |
+
#Primitive("flatten", arrow(tlist(tlist(t0)), tlist(t0)), _flatten),
|
289 |
+
# (lambda (reduce (lambda (lambda (++ $1 $0))) empty $0))
|
290 |
+
Primitive("sum", arrow(tlist(tint), tint), sum),
|
291 |
+
# (lambda (lambda (reduce (lambda (lambda (+ $0 $1))) 0 $0)))
|
292 |
+
Primitive("reverse", arrow(tlist(t0), tlist(t0)), _reverse),
|
293 |
+
# (lambda (reduce (lambda (lambda (++ (singleton $0) $1))) empty $0))
|
294 |
+
Primitive("all", arrow(arrow(t0, tbool), tlist(t0), tbool), _all),
|
295 |
+
# (lambda (lambda (reduce (lambda (lambda (and $0 $1))) true (map $1 $0))))
|
296 |
+
Primitive("any", arrow(arrow(t0, tbool), tlist(t0), tbool), _any),
|
297 |
+
# (lambda (lambda (reduce (lambda (lambda (or $0 $1))) true (map $1 $0))))
|
298 |
+
Primitive("index", arrow(tint, tlist(t0), t0), _index),
|
299 |
+
# (lambda (lambda (reducei (lambda (lambda (lambda (if (eq? $1 $4) $0 0)))) 0 $0)))
|
300 |
+
Primitive("filter", arrow(arrow(t0, tbool), tlist(t0), tlist(t0)), _filter),
|
301 |
+
# (lambda (lambda (reduce (lambda (lambda (++ $1 (if ($3 $0) (singleton $0) empty)))) empty $0)))
|
302 |
+
#Primitive("replace", arrow(arrow(tint, t0, tbool), tlist(t0), tlist(t0), tlist(t0)), _replace),
|
303 |
+
# (FLATTEN (lambda (lambda (lambda (mapi (lambda (lambda (if ($4 $1 $0) $3 (singleton $1)))) $0)))))
|
304 |
+
Primitive("slice", arrow(tint, tint, tlist(t0), tlist(t0)), _slice),
|
305 |
+
# (lambda (lambda (lambda (reducei (lambda (lambda (lambda (++ $2 (if (and (or (gt? $1 $5) (eq? $1 $5)) (not (or (gt? $4 $1) (eq? $1 $4)))) (singleton $0) empty))))) empty $0))))
|
306 |
+
]
|
307 |
+
|
308 |
+
|
309 |
+
def basePrimitives():
|
310 |
+
return [Primitive(str(j), tint, j) for j in range(6)] + [
|
311 |
+
Primitive("*", arrow(tint, tint, tint), _multiplication),
|
312 |
+
Primitive("gt?", arrow(tint, tint, tbool), _gt),
|
313 |
+
Primitive("is-prime", arrow(tint, tbool), _isPrime),
|
314 |
+
Primitive("is-square", arrow(tint, tbool), _isSquare),
|
315 |
+
# McCarthy
|
316 |
+
Primitive("empty", tlist(t0), []),
|
317 |
+
Primitive("cons", arrow(t0, tlist(t0), tlist(t0)), _cons),
|
318 |
+
Primitive("car", arrow(tlist(t0), t0), _car),
|
319 |
+
Primitive("cdr", arrow(tlist(t0), tlist(t0)), _cdr),
|
320 |
+
Primitive("empty?", arrow(tlist(t0), tbool), _isEmpty),
|
321 |
+
Primitive("if", arrow(tbool, t0, t0, t0), _if),
|
322 |
+
Primitive("eq?", arrow(tint, tint, tbool), _eq),
|
323 |
+
Primitive("+", arrow(tint, tint, tint), _addition),
|
324 |
+
Primitive("-", arrow(tint, tint, tint), _subtraction)
|
325 |
+
]
|
326 |
+
|
327 |
+
zip_primitive = Primitive("zip", arrow(tlist(t0), tlist(t1), arrow(t0, t1, t2), tlist(t2)), _zip)
|
328 |
+
|
329 |
+
def bootstrapTarget():
|
330 |
+
"""These are the primitives that we hope to learn from the bootstrapping procedure"""
|
331 |
+
return [
|
332 |
+
# learned primitives
|
333 |
+
Primitive("map", arrow(arrow(t0, t1), tlist(t0), tlist(t1)), _map),
|
334 |
+
Primitive("unfold", arrow(t0, arrow(t0,tbool), arrow(t0,t1), arrow(t0,t0), tlist(t1)), _unfold),
|
335 |
+
Primitive("range", arrow(tint, tlist(tint)), _range),
|
336 |
+
Primitive("index", arrow(tint, tlist(t0), t0), _index),
|
337 |
+
Primitive("fold", arrow(tlist(t0), t1, arrow(t0, t1, t1), t1), _fold),
|
338 |
+
Primitive("length", arrow(tlist(t0), tint), len),
|
339 |
+
|
340 |
+
# built-ins
|
341 |
+
Primitive("if", arrow(tbool, t0, t0, t0), _if),
|
342 |
+
Primitive("+", arrow(tint, tint, tint), _addition),
|
343 |
+
Primitive("-", arrow(tint, tint, tint), _subtraction),
|
344 |
+
Primitive("empty", tlist(t0), []),
|
345 |
+
Primitive("cons", arrow(t0, tlist(t0), tlist(t0)), _cons),
|
346 |
+
Primitive("car", arrow(tlist(t0), t0), _car),
|
347 |
+
Primitive("cdr", arrow(tlist(t0), tlist(t0)), _cdr),
|
348 |
+
Primitive("empty?", arrow(tlist(t0), tbool), _isEmpty),
|
349 |
+
] + [Primitive(str(j), tint, j) for j in range(2)]
|
350 |
+
|
351 |
+
|
352 |
+
def bootstrapTarget_extra():
|
353 |
+
"""This is the bootstrap target plus list domain specific stuff"""
|
354 |
+
return bootstrapTarget() + [
|
355 |
+
Primitive("*", arrow(tint, tint, tint), _multiplication),
|
356 |
+
Primitive("mod", arrow(tint, tint, tint), _mod),
|
357 |
+
Primitive("gt?", arrow(tint, tint, tbool), _gt),
|
358 |
+
Primitive("eq?", arrow(tint, tint, tbool), _eq),
|
359 |
+
Primitive("is-prime", arrow(tint, tbool), _isPrime),
|
360 |
+
Primitive("is-square", arrow(tint, tbool), _isSquare),
|
361 |
+
]
|
362 |
+
|
363 |
+
def no_length():
|
364 |
+
"""this is the primitives without length because one of the reviewers wanted this"""
|
365 |
+
return [p for p in bootstrapTarget() if p.name != "length"] + [
|
366 |
+
Primitive("*", arrow(tint, tint, tint), _multiplication),
|
367 |
+
Primitive("mod", arrow(tint, tint, tint), _mod),
|
368 |
+
Primitive("gt?", arrow(tint, tint, tbool), _gt),
|
369 |
+
Primitive("eq?", arrow(tint, tint, tbool), _eq),
|
370 |
+
Primitive("is-prime", arrow(tint, tbool), _isPrime),
|
371 |
+
Primitive("is-square", arrow(tint, tbool), _isSquare),
|
372 |
+
]
|
373 |
+
|
374 |
+
|
375 |
+
def McCarthyPrimitives():
|
376 |
+
"These are < primitives provided by 1959 lisp as introduced by McCarthy"
|
377 |
+
return [
|
378 |
+
Primitive("empty", tlist(t0), []),
|
379 |
+
Primitive("cons", arrow(t0, tlist(t0), tlist(t0)), _cons),
|
380 |
+
Primitive("car", arrow(tlist(t0), t0), _car),
|
381 |
+
Primitive("cdr", arrow(tlist(t0), tlist(t0)), _cdr),
|
382 |
+
Primitive("empty?", arrow(tlist(t0), tbool), _isEmpty),
|
383 |
+
#Primitive("unfold", arrow(t0, arrow(t0,t1), arrow(t0,t0), arrow(t0,tbool), tlist(t1)), _isEmpty),
|
384 |
+
#Primitive("1+", arrow(tint,tint),None),
|
385 |
+
# Primitive("range", arrow(tint, tlist(tint)), range),
|
386 |
+
# Primitive("map", arrow(arrow(t0, t1), tlist(t0), tlist(t1)), _map),
|
387 |
+
# Primitive("index", arrow(tint,tlist(t0),t0),None),
|
388 |
+
# Primitive("length", arrow(tlist(t0),tint),None),
|
389 |
+
primitiveRecursion1,
|
390 |
+
#primitiveRecursion2,
|
391 |
+
Primitive("gt?", arrow(tint, tint, tbool), _gt),
|
392 |
+
Primitive("if", arrow(tbool, t0, t0, t0), _if),
|
393 |
+
Primitive("eq?", arrow(tint, tint, tbool), _eq),
|
394 |
+
Primitive("+", arrow(tint, tint, tint), _addition),
|
395 |
+
Primitive("-", arrow(tint, tint, tint), _subtraction),
|
396 |
+
] + [Primitive(str(j), tint, j) for j in range(2)]
|
397 |
+
|
398 |
+
|
399 |
+
if __name__ == "__main__":
|
400 |
+
bootstrapTarget()
|
401 |
+
g = Grammar.uniform(McCarthyPrimitives())
|
402 |
+
# with open("/home/ellisk/om/ec/experimentOutputs/list_aic=1.0_arity=3_ET=1800_expandFrontier=2.0_it=4_likelihoodModel=all-or-nothing_MF=5_baseline=False_pc=10.0_L=1.0_K=5_rec=False.pickle", "rb") as handle:
|
403 |
+
# b = pickle.load(handle).grammars[-1]
|
404 |
+
# print b
|
405 |
+
|
406 |
+
p = Program.parse(
|
407 |
+
"(lambda (lambda (lambda (if (empty? $0) empty (cons (+ (car $1) (car $0)) ($2 (cdr $1) (cdr $0)))))))")
|
408 |
+
t = arrow(tlist(tint), tlist(tint), tlist(tint)) # ,tlist(tbool))
|
409 |
+
print(g.logLikelihood(arrow(t, t), p))
|
410 |
+
assert False
|
411 |
+
print(b.logLikelihood(arrow(t, t), p))
|
412 |
+
|
413 |
+
# p = Program.parse("""(lambda (lambda
|
414 |
+
# (unfold 0
|
415 |
+
# (lambda (+ (index $0 $2) (index $0 $1)))
|
416 |
+
# (lambda (1+ $0))
|
417 |
+
# (lambda (eq? $0 (length $1))))))
|
418 |
+
# """)
|
419 |
+
p = Program.parse("""(lambda (lambda
|
420 |
+
(map (lambda (+ (index $0 $2) (index $0 $1))) (range (length $0)) )))""")
|
421 |
+
# .replace("unfold", "#(lambda (lambda (lambda (lambda (fix1 $0 (lambda (lambda (#(lambda (lambda (lambda (if $0 empty (cons $1 $2))))) ($1 ($3 $0)) ($4 $0) ($5 $0)))))))))").\
|
422 |
+
# replace("length", "#(lambda (fix1 $0 (lambda (lambda (if (empty? $0) 0 (+ ($1 (cdr $0)) 1))))))").\
|
423 |
+
# replace("forloop", "(#(lambda (lambda (lambda (lambda (fix1 $0 (lambda (lambda (#(lambda (lambda (lambda (if $0 empty (cons $1 $2))))) ($1 ($3 $0)) ($4 $0) ($5 $0))))))))) (lambda (#(eq? 0) $0)) $0 (lambda (#(lambda (- $0 1)) $0)))").\
|
424 |
+
# replace("inc","#(lambda (+ $0 1))").\
|
425 |
+
# replace("drop","#(lambda (lambda (fix2 $0 $1 (lambda (lambda (lambda (if
|
426 |
+
# (#(eq? 0) $1) $0 (cdr ($2 (- $1 1) $0)))))))))"))
|
427 |
+
print(p)
|
428 |
+
print(g.logLikelihood(t, p))
|
429 |
+
assert False
|
430 |
+
|
431 |
+
print("??")
|
432 |
+
p = Program.parse(
|
433 |
+
"#(lambda (#(lambda (lambda (lambda (fix1 $0 (lambda (lambda (if (empty? $0) $3 ($4 (car $0) ($1 (cdr $0)))))))))) (lambda $1) 1))")
|
434 |
+
for j in range(10):
|
435 |
+
l = list(range(j))
|
436 |
+
print(l, p.evaluate([])(lambda x: x * 2)(l))
|
437 |
+
print()
|
438 |
+
print()
|
439 |
+
|
440 |
+
print("multiply")
|
441 |
+
p = Program.parse(
|
442 |
+
"(lambda (lambda (lambda (if (eq? $0 0) 0 (+ $1 ($2 $1 (- $0 1)))))))")
|
443 |
+
print(g.logLikelihood(arrow(arrow(tint, tint, tint), tint, tint, tint), p))
|
444 |
+
print()
|
445 |
+
|
446 |
+
print("take until 0")
|
447 |
+
p = Program.parse("(lambda (lambda (if (eq? $1 0) empty (cons $1 $0))))")
|
448 |
+
print(g.logLikelihood(arrow(tint, tlist(tint), tlist(tint)), p))
|
449 |
+
print()
|
450 |
+
|
451 |
+
print("countdown primitive")
|
452 |
+
p = Program.parse(
|
453 |
+
"(lambda (lambda (if (eq? $0 0) empty (cons (+ $0 1) ($1 (- $0 1))))))")
|
454 |
+
print(
|
455 |
+
g.logLikelihood(
|
456 |
+
arrow(
|
457 |
+
arrow(
|
458 |
+
tint, tlist(tint)), arrow(
|
459 |
+
tint, tlist(tint))), p))
|
460 |
+
print(_fix(9)(p.evaluate([])))
|
461 |
+
print("countdown w/ better primitives")
|
462 |
+
p = Program.parse(
|
463 |
+
"(lambda (lambda (if (eq0 $0) empty (cons (+1 $0) ($1 (-1 $0))))))")
|
464 |
+
print(
|
465 |
+
g.logLikelihood(
|
466 |
+
arrow(
|
467 |
+
arrow(
|
468 |
+
tint, tlist(tint)), arrow(
|
469 |
+
tint, tlist(tint))), p))
|
470 |
+
|
471 |
+
print()
|
472 |
+
|
473 |
+
print("prepend zeros")
|
474 |
+
p = Program.parse(
|
475 |
+
"(lambda (lambda (lambda (if (eq? $1 0) $0 (cons 0 ($2 (- $1 1) $0))))))")
|
476 |
+
print(
|
477 |
+
g.logLikelihood(
|
478 |
+
arrow(
|
479 |
+
arrow(
|
480 |
+
tint,
|
481 |
+
tlist(tint),
|
482 |
+
tlist(tint)),
|
483 |
+
tint,
|
484 |
+
tlist(tint),
|
485 |
+
tlist(tint)),
|
486 |
+
p))
|
487 |
+
print()
|
488 |
+
assert False
|
489 |
+
|
490 |
+
p = Program.parse(
|
491 |
+
"(lambda (fix1 $0 (lambda (lambda (if (empty? $0) 0 (+ 1 ($1 (cdr $0))))))))")
|
492 |
+
print(p.evaluate([])(list(range(17))))
|
493 |
+
print(g.logLikelihood(arrow(tlist(tbool), tint), p))
|
494 |
+
|
495 |
+
p = Program.parse(
|
496 |
+
"(lambda (lambda (if (empty? $0) 0 (+ 1 ($1 (cdr $0))))))")
|
497 |
+
print(
|
498 |
+
g.logLikelihood(
|
499 |
+
arrow(
|
500 |
+
arrow(
|
501 |
+
tlist(tbool), tint), arrow(
|
502 |
+
tlist(tbool), tint)), p))
|
503 |
+
|
504 |
+
p = Program.parse(
|
505 |
+
"(lambda (fix1 $0 (lambda (lambda (if (empty? $0) 0 (+ (car $0) ($1 (cdr $0))))))))")
|
506 |
+
|
507 |
+
print(p.evaluate([])(list(range(4))))
|
508 |
+
print(g.logLikelihood(arrow(tlist(tint), tint), p))
|
509 |
+
|
510 |
+
p = Program.parse(
|
511 |
+
"(lambda (lambda (if (empty? $0) 0 (+ (car $0) ($1 (cdr $0))))))")
|
512 |
+
print(p)
|
513 |
+
print(
|
514 |
+
g.logLikelihood(
|
515 |
+
arrow(
|
516 |
+
arrow(
|
517 |
+
tlist(tint),
|
518 |
+
tint),
|
519 |
+
tlist(tint),
|
520 |
+
tint),
|
521 |
+
p))
|
522 |
+
|
523 |
+
print("take")
|
524 |
+
p = Program.parse(
|
525 |
+
"(lambda (lambda (lambda (if (eq? $1 0) empty (cons (car $0) ($2 (- $1 1) (cdr $0)))))))")
|
526 |
+
print(p)
|
527 |
+
print(
|
528 |
+
g.logLikelihood(
|
529 |
+
arrow(
|
530 |
+
arrow(
|
531 |
+
tint,
|
532 |
+
tlist(tint),
|
533 |
+
tlist(tint)),
|
534 |
+
tint,
|
535 |
+
tlist(tint),
|
536 |
+
tlist(tint)),
|
537 |
+
p))
|
538 |
+
assert False
|
539 |
+
|
540 |
+
print(p.evaluate([])(list(range(4))))
|
541 |
+
print(g.logLikelihood(arrow(tlist(tint), tlist(tint)), p))
|
542 |
+
|
543 |
+
p = Program.parse(
|
544 |
+
"""(lambda (fix (lambda (lambda (match $0 0 (lambda (lambda (+ $1 ($3 $0))))))) $0))""")
|
545 |
+
print(p.evaluate([])(list(range(4))))
|
546 |
+
print(g.logLikelihood(arrow(tlist(tint), tint), p))
|
dreamcoder/domains/list/main.py
ADDED
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
from collections import defaultdict
|
3 |
+
import json
|
4 |
+
import math
|
5 |
+
import os
|
6 |
+
import datetime
|
7 |
+
|
8 |
+
from dreamcoder.dreamcoder import explorationCompression
|
9 |
+
from dreamcoder.utilities import eprint, flatten, testTrainSplit
|
10 |
+
from dreamcoder.grammar import Grammar
|
11 |
+
from dreamcoder.task import Task
|
12 |
+
from dreamcoder.type import Context, arrow, tbool, tlist, tint, t0, UnificationFailure
|
13 |
+
from dreamcoder.domains.list.listPrimitives import basePrimitives, primitives, McCarthyPrimitives, bootstrapTarget_extra, no_length
|
14 |
+
from dreamcoder.domains.list.makeListTasks import make_list_bootstrap_tasks, sortBootstrap, EASYLISTTASKS
|
15 |
+
|
16 |
+
|
17 |
+
def retrieveJSONTasks(filename, features=False):
|
18 |
+
"""
|
19 |
+
For JSON of the form:
|
20 |
+
{"name": str,
|
21 |
+
"type": {"input" : bool|int|list-of-bool|list-of-int,
|
22 |
+
"output": bool|int|list-of-bool|list-of-int},
|
23 |
+
"examples": [{"i": data, "o": data}]}
|
24 |
+
"""
|
25 |
+
with open(filename, "r") as f:
|
26 |
+
loaded = json.load(f)
|
27 |
+
TP = {
|
28 |
+
"bool": tbool,
|
29 |
+
"int": tint,
|
30 |
+
"list-of-bool": tlist(tbool),
|
31 |
+
"list-of-int": tlist(tint),
|
32 |
+
}
|
33 |
+
return [Task(
|
34 |
+
item["name"],
|
35 |
+
arrow(TP[item["type"]["input"]], TP[item["type"]["output"]]),
|
36 |
+
[((ex["i"],), ex["o"]) for ex in item["examples"]],
|
37 |
+
features=(None if not features else list_features(
|
38 |
+
[((ex["i"],), ex["o"]) for ex in item["examples"]])),
|
39 |
+
cache=False,
|
40 |
+
) for item in loaded]
|
41 |
+
|
42 |
+
|
43 |
+
def list_features(examples):
|
44 |
+
if any(isinstance(i, int) for (i,), _ in examples):
|
45 |
+
# obtain features for number inputs as list of numbers
|
46 |
+
examples = [(([i],), o) for (i,), o in examples]
|
47 |
+
elif any(not isinstance(i, list) for (i,), _ in examples):
|
48 |
+
# can't handle non-lists
|
49 |
+
return []
|
50 |
+
elif any(isinstance(x, list) for (xs,), _ in examples for x in xs):
|
51 |
+
# nested lists are hard to extract features for, so we'll
|
52 |
+
# obtain features as if flattened
|
53 |
+
examples = [(([x for xs in ys for x in xs],), o)
|
54 |
+
for (ys,), o in examples]
|
55 |
+
|
56 |
+
# assume all tasks have the same number of examples
|
57 |
+
# and all inputs are lists
|
58 |
+
features = []
|
59 |
+
ot = type(examples[0][1])
|
60 |
+
|
61 |
+
def mean(l): return 0 if not l else sum(l) / len(l)
|
62 |
+
imean = [mean(i) for (i,), o in examples]
|
63 |
+
ivar = [sum((v - imean[idx])**2
|
64 |
+
for v in examples[idx][0][0])
|
65 |
+
for idx in range(len(examples))]
|
66 |
+
|
67 |
+
# DISABLED length of each input and output
|
68 |
+
# total difference between length of input and output
|
69 |
+
# DISABLED normalized count of numbers in input but not in output
|
70 |
+
# total normalized count of numbers in input but not in output
|
71 |
+
# total difference between means of input and output
|
72 |
+
# total difference between variances of input and output
|
73 |
+
# output type (-1=bool, 0=int, 1=list)
|
74 |
+
# DISABLED outputs if integers, else -1s
|
75 |
+
# DISABLED outputs if bools (-1/1), else 0s
|
76 |
+
if ot == list: # lists of ints or bools
|
77 |
+
omean = [mean(o) for (i,), o in examples]
|
78 |
+
ovar = [sum((v - omean[idx])**2
|
79 |
+
for v in examples[idx][1])
|
80 |
+
for idx in range(len(examples))]
|
81 |
+
|
82 |
+
def cntr(
|
83 |
+
l, o): return 0 if not l else len(
|
84 |
+
set(l).difference(
|
85 |
+
set(o))) / len(l)
|
86 |
+
cnt_not_in_output = [cntr(i, o) for (i,), o in examples]
|
87 |
+
|
88 |
+
#features += [len(i) for (i,), o in examples]
|
89 |
+
#features += [len(o) for (i,), o in examples]
|
90 |
+
features.append(sum(len(i) - len(o) for (i,), o in examples))
|
91 |
+
#features += cnt_not_int_output
|
92 |
+
features.append(sum(cnt_not_in_output))
|
93 |
+
features.append(sum(om - im for im, om in zip(imean, omean)))
|
94 |
+
features.append(sum(ov - iv for iv, ov in zip(ivar, ovar)))
|
95 |
+
features.append(1)
|
96 |
+
# features += [-1 for _ in examples]
|
97 |
+
# features += [0 for _ in examples]
|
98 |
+
elif ot == bool:
|
99 |
+
outs = [o for (i,), o in examples]
|
100 |
+
|
101 |
+
#features += [len(i) for (i,), o in examples]
|
102 |
+
#features += [-1 for _ in examples]
|
103 |
+
features.append(sum(len(i) for (i,), o in examples))
|
104 |
+
#features += [0 for _ in examples]
|
105 |
+
features.append(0)
|
106 |
+
features.append(sum(imean))
|
107 |
+
features.append(sum(ivar))
|
108 |
+
features.append(-1)
|
109 |
+
# features += [-1 for _ in examples]
|
110 |
+
# features += [1 if o else -1 for o in outs]
|
111 |
+
else: # int
|
112 |
+
def cntr(
|
113 |
+
l, o): return 0 if not l else len(
|
114 |
+
set(l).difference(
|
115 |
+
set(o))) / len(l)
|
116 |
+
cnt_not_in_output = [cntr(i, [o]) for (i,), o in examples]
|
117 |
+
outs = [o for (i,), o in examples]
|
118 |
+
|
119 |
+
#features += [len(i) for (i,), o in examples]
|
120 |
+
#features += [1 for (i,), o in examples]
|
121 |
+
features.append(sum(len(i) for (i,), o in examples))
|
122 |
+
#features += cnt_not_int_output
|
123 |
+
features.append(sum(cnt_not_in_output))
|
124 |
+
features.append(sum(o - im for im, o in zip(imean, outs)))
|
125 |
+
features.append(sum(ivar))
|
126 |
+
features.append(0)
|
127 |
+
# features += outs
|
128 |
+
# features += [0 for _ in examples]
|
129 |
+
|
130 |
+
return features
|
131 |
+
|
132 |
+
|
133 |
+
def isListFunction(tp):
|
134 |
+
try:
|
135 |
+
Context().unify(tp, arrow(tlist(tint), t0))
|
136 |
+
return True
|
137 |
+
except UnificationFailure:
|
138 |
+
return False
|
139 |
+
|
140 |
+
|
141 |
+
def isIntFunction(tp):
|
142 |
+
try:
|
143 |
+
Context().unify(tp, arrow(tint, t0))
|
144 |
+
return True
|
145 |
+
except UnificationFailure:
|
146 |
+
return False
|
147 |
+
|
148 |
+
try:
|
149 |
+
from dreamcoder.recognition import RecurrentFeatureExtractor
|
150 |
+
class LearnedFeatureExtractor(RecurrentFeatureExtractor):
|
151 |
+
H = 64
|
152 |
+
|
153 |
+
special = None
|
154 |
+
|
155 |
+
def tokenize(self, examples):
|
156 |
+
def sanitize(l): return [z if z in self.lexicon else "?"
|
157 |
+
for z_ in l
|
158 |
+
for z in (z_ if isinstance(z_, list) else [z_])]
|
159 |
+
|
160 |
+
tokenized = []
|
161 |
+
for xs, y in examples:
|
162 |
+
if isinstance(y, list):
|
163 |
+
y = ["LIST_START"] + y + ["LIST_END"]
|
164 |
+
else:
|
165 |
+
y = [y]
|
166 |
+
y = sanitize(y)
|
167 |
+
if len(y) > self.maximumLength:
|
168 |
+
return None
|
169 |
+
|
170 |
+
serializedInputs = []
|
171 |
+
for xi, x in enumerate(xs):
|
172 |
+
if isinstance(x, list):
|
173 |
+
x = ["LIST_START"] + x + ["LIST_END"]
|
174 |
+
else:
|
175 |
+
x = [x]
|
176 |
+
x = sanitize(x)
|
177 |
+
if len(x) > self.maximumLength:
|
178 |
+
return None
|
179 |
+
serializedInputs.append(x)
|
180 |
+
|
181 |
+
tokenized.append((tuple(serializedInputs), y))
|
182 |
+
|
183 |
+
return tokenized
|
184 |
+
|
185 |
+
def __init__(self, tasks, testingTasks=[], cuda=False):
|
186 |
+
self.lexicon = set(flatten((t.examples for t in tasks + testingTasks), abort=lambda x: isinstance(
|
187 |
+
x, str))).union({"LIST_START", "LIST_END", "?"})
|
188 |
+
|
189 |
+
# Calculate the maximum length
|
190 |
+
self.maximumLength = float('inf') # Believe it or not this is actually important to have here
|
191 |
+
self.maximumLength = max(len(l)
|
192 |
+
for t in tasks + testingTasks
|
193 |
+
for xs, y in self.tokenize(t.examples)
|
194 |
+
for l in [y] + [x for x in xs])
|
195 |
+
|
196 |
+
self.recomputeTasks = True
|
197 |
+
|
198 |
+
super(
|
199 |
+
LearnedFeatureExtractor,
|
200 |
+
self).__init__(
|
201 |
+
lexicon=list(
|
202 |
+
self.lexicon),
|
203 |
+
tasks=tasks,
|
204 |
+
cuda=cuda,
|
205 |
+
H=self.H,
|
206 |
+
bidirectional=True)
|
207 |
+
except: pass
|
208 |
+
|
209 |
+
def train_necessary(t):
|
210 |
+
if t.name in {"head", "is-primes", "len", "pop", "repeat-many", "tail", "keep primes", "keep squares"}:
|
211 |
+
return True
|
212 |
+
if any(t.name.startswith(x) for x in {
|
213 |
+
"add-k", "append-k", "bool-identify-geq-k", "count-k", "drop-k",
|
214 |
+
"empty", "evens", "has-k", "index-k", "is-mod-k", "kth-largest",
|
215 |
+
"kth-smallest", "modulo-k", "mult-k", "remove-index-k",
|
216 |
+
"remove-mod-k", "repeat-k", "replace-all-with-index-k", "rotate-k",
|
217 |
+
"slice-k-n", "take-k",
|
218 |
+
}):
|
219 |
+
return "some"
|
220 |
+
return False
|
221 |
+
|
222 |
+
|
223 |
+
def list_options(parser):
|
224 |
+
parser.add_argument(
|
225 |
+
"--noMap", action="store_true", default=False,
|
226 |
+
help="Disable built-in map primitive")
|
227 |
+
parser.add_argument(
|
228 |
+
"--noUnfold", action="store_true", default=False,
|
229 |
+
help="Disable built-in unfold primitive")
|
230 |
+
parser.add_argument(
|
231 |
+
"--noLength", action="store_true", default=False,
|
232 |
+
help="Disable built-in length primitive")
|
233 |
+
parser.add_argument(
|
234 |
+
"--dataset",
|
235 |
+
type=str,
|
236 |
+
default="Lucas-old",
|
237 |
+
choices=[
|
238 |
+
"bootstrap",
|
239 |
+
"sorting",
|
240 |
+
"Lucas-old",
|
241 |
+
"Lucas-depth1",
|
242 |
+
"Lucas-depth2",
|
243 |
+
"Lucas-depth3"])
|
244 |
+
parser.add_argument("--maxTasks", type=int,
|
245 |
+
default=None,
|
246 |
+
help="truncate tasks to fit within this boundary")
|
247 |
+
parser.add_argument("--primitives",
|
248 |
+
default="common",
|
249 |
+
help="Which primitive set to use",
|
250 |
+
choices=["McCarthy", "base", "rich", "common", "noLength"])
|
251 |
+
parser.add_argument("--extractor", type=str,
|
252 |
+
choices=["hand", "deep", "learned"],
|
253 |
+
default="learned")
|
254 |
+
parser.add_argument("--split", metavar="TRAIN_RATIO",
|
255 |
+
type=float,
|
256 |
+
help="split test/train")
|
257 |
+
parser.add_argument("-H", "--hidden", type=int,
|
258 |
+
default=64,
|
259 |
+
help="number of hidden units")
|
260 |
+
parser.add_argument("--random-seed", type=int, default=17)
|
261 |
+
|
262 |
+
|
263 |
+
def main(args):
|
264 |
+
"""
|
265 |
+
Takes the return value of the `commandlineArguments()` function as input and
|
266 |
+
trains/tests the model on manipulating sequences of numbers.
|
267 |
+
"""
|
268 |
+
random.seed(args.pop("random_seed"))
|
269 |
+
|
270 |
+
dataset = args.pop("dataset")
|
271 |
+
tasks = {
|
272 |
+
"Lucas-old": lambda: retrieveJSONTasks("data/list_tasks.json") + sortBootstrap(),
|
273 |
+
"bootstrap": make_list_bootstrap_tasks,
|
274 |
+
"sorting": sortBootstrap,
|
275 |
+
"Lucas-depth1": lambda: retrieveJSONTasks("data/list_tasks2.json")[:105],
|
276 |
+
"Lucas-depth2": lambda: retrieveJSONTasks("data/list_tasks2.json")[:4928],
|
277 |
+
"Lucas-depth3": lambda: retrieveJSONTasks("data/list_tasks2.json"),
|
278 |
+
}[dataset]()
|
279 |
+
|
280 |
+
maxTasks = args.pop("maxTasks")
|
281 |
+
if maxTasks and len(tasks) > maxTasks:
|
282 |
+
necessaryTasks = [] # maxTasks will not consider these
|
283 |
+
if dataset.startswith("Lucas2.0") and dataset != "Lucas2.0-depth1":
|
284 |
+
necessaryTasks = tasks[:105]
|
285 |
+
|
286 |
+
eprint("Unwilling to handle {} tasks, truncating..".format(len(tasks)))
|
287 |
+
random.shuffle(tasks)
|
288 |
+
del tasks[maxTasks:]
|
289 |
+
tasks = necessaryTasks + tasks
|
290 |
+
|
291 |
+
if dataset.startswith("Lucas"):
|
292 |
+
# extra tasks for filter
|
293 |
+
tasks.extend([
|
294 |
+
Task("remove empty lists",
|
295 |
+
arrow(tlist(tlist(tbool)), tlist(tlist(tbool))),
|
296 |
+
[((ls,), list(filter(lambda l: len(l) > 0, ls)))
|
297 |
+
for _ in range(15)
|
298 |
+
for ls in [[[random.random() < 0.5 for _ in range(random.randint(0, 3))]
|
299 |
+
for _ in range(4)]]]),
|
300 |
+
Task("keep squares",
|
301 |
+
arrow(tlist(tint), tlist(tint)),
|
302 |
+
[((xs,), list(filter(lambda x: int(math.sqrt(x)) ** 2 == x,
|
303 |
+
xs)))
|
304 |
+
for _ in range(15)
|
305 |
+
for xs in [[random.choice([0, 1, 4, 9, 16, 25])
|
306 |
+
if random.random() < 0.5
|
307 |
+
else random.randint(0, 9)
|
308 |
+
for _ in range(7)]]]),
|
309 |
+
Task("keep primes",
|
310 |
+
arrow(tlist(tint), tlist(tint)),
|
311 |
+
[((xs,), list(filter(lambda x: x in {2, 3, 5, 7, 11, 13, 17,
|
312 |
+
19, 23, 29, 31, 37}, xs)))
|
313 |
+
for _ in range(15)
|
314 |
+
for xs in [[random.choice([2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37])
|
315 |
+
if random.random() < 0.5
|
316 |
+
else random.randint(0, 9)
|
317 |
+
for _ in range(7)]]]),
|
318 |
+
])
|
319 |
+
for i in range(4):
|
320 |
+
tasks.extend([
|
321 |
+
Task("keep eq %s" % i,
|
322 |
+
arrow(tlist(tint), tlist(tint)),
|
323 |
+
[((xs,), list(filter(lambda x: x == i, xs)))
|
324 |
+
for _ in range(15)
|
325 |
+
for xs in [[random.randint(0, 6) for _ in range(5)]]]),
|
326 |
+
Task("remove eq %s" % i,
|
327 |
+
arrow(tlist(tint), tlist(tint)),
|
328 |
+
[((xs,), list(filter(lambda x: x != i, xs)))
|
329 |
+
for _ in range(15)
|
330 |
+
for xs in [[random.randint(0, 6) for _ in range(5)]]]),
|
331 |
+
Task("keep gt %s" % i,
|
332 |
+
arrow(tlist(tint), tlist(tint)),
|
333 |
+
[((xs,), list(filter(lambda x: x > i, xs)))
|
334 |
+
for _ in range(15)
|
335 |
+
for xs in [[random.randint(0, 6) for _ in range(5)]]]),
|
336 |
+
Task("remove gt %s" % i,
|
337 |
+
arrow(tlist(tint), tlist(tint)),
|
338 |
+
[((xs,), list(filter(lambda x: not x > i, xs)))
|
339 |
+
for _ in range(15)
|
340 |
+
for xs in [[random.randint(0, 6) for _ in range(5)]]])
|
341 |
+
])
|
342 |
+
|
343 |
+
def isIdentityTask(t):
|
344 |
+
return all( len(xs) == 1 and xs[0] == y for xs, y in t.examples )
|
345 |
+
eprint("Removed", sum(isIdentityTask(t) for t in tasks), "tasks that were just the identity function")
|
346 |
+
tasks = [t for t in tasks if not isIdentityTask(t) ]
|
347 |
+
|
348 |
+
prims = {"base": basePrimitives,
|
349 |
+
"McCarthy": McCarthyPrimitives,
|
350 |
+
"common": bootstrapTarget_extra,
|
351 |
+
"noLength": no_length,
|
352 |
+
"rich": primitives}[args.pop("primitives")]()
|
353 |
+
haveLength = not args.pop("noLength")
|
354 |
+
haveMap = not args.pop("noMap")
|
355 |
+
haveUnfold = not args.pop("noUnfold")
|
356 |
+
eprint(f"Including map as a primitive? {haveMap}")
|
357 |
+
eprint(f"Including length as a primitive? {haveLength}")
|
358 |
+
eprint(f"Including unfold as a primitive? {haveUnfold}")
|
359 |
+
baseGrammar = Grammar.uniform([p
|
360 |
+
for p in prims
|
361 |
+
if (p.name != "map" or haveMap) and \
|
362 |
+
(p.name != "unfold" or haveUnfold) and \
|
363 |
+
(p.name != "length" or haveLength)])
|
364 |
+
|
365 |
+
extractor = {
|
366 |
+
"learned": LearnedFeatureExtractor,
|
367 |
+
}[args.pop("extractor")]
|
368 |
+
extractor.H = args.pop("hidden")
|
369 |
+
|
370 |
+
timestamp = datetime.datetime.now().isoformat()
|
371 |
+
outputDirectory = "experimentOutputs/list/%s"%timestamp
|
372 |
+
os.system("mkdir -p %s"%outputDirectory)
|
373 |
+
|
374 |
+
args.update({
|
375 |
+
"featureExtractor": extractor,
|
376 |
+
"outputPrefix": "%s/list"%outputDirectory,
|
377 |
+
"evaluationTimeout": 0.0005,
|
378 |
+
})
|
379 |
+
|
380 |
+
|
381 |
+
eprint("Got {} list tasks".format(len(tasks)))
|
382 |
+
split = args.pop("split")
|
383 |
+
if split:
|
384 |
+
train_some = defaultdict(list)
|
385 |
+
for t in tasks:
|
386 |
+
necessary = train_necessary(t)
|
387 |
+
if not necessary:
|
388 |
+
continue
|
389 |
+
if necessary == "some":
|
390 |
+
train_some[t.name.split()[0]].append(t)
|
391 |
+
else:
|
392 |
+
t.mustTrain = True
|
393 |
+
for k in sorted(train_some):
|
394 |
+
ts = train_some[k]
|
395 |
+
random.shuffle(ts)
|
396 |
+
ts.pop().mustTrain = True
|
397 |
+
|
398 |
+
test, train = testTrainSplit(tasks, split)
|
399 |
+
if True:
|
400 |
+
test = [t for t in test
|
401 |
+
if t.name not in EASYLISTTASKS]
|
402 |
+
|
403 |
+
eprint(
|
404 |
+
"Alotted {} tasks for training and {} for testing".format(
|
405 |
+
len(train), len(test)))
|
406 |
+
else:
|
407 |
+
train = tasks
|
408 |
+
test = []
|
409 |
+
|
410 |
+
explorationCompression(baseGrammar, train, testingTasks=test, **args)
|
dreamcoder/domains/list/makeListTasks.py
ADDED
@@ -0,0 +1,587 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
from dreamcoder.type import *
|
4 |
+
from dreamcoder.task import Task
|
5 |
+
from dreamcoder.utilities import eprint, hashable
|
6 |
+
|
7 |
+
from random import randint, random, seed
|
8 |
+
from itertools import product
|
9 |
+
|
10 |
+
# Excluded routines either impossible or astronomically improbable
|
11 |
+
# I'm cutting these off at ~20 nats in learned grammars.
|
12 |
+
EXCLUDES = {
|
13 |
+
"dedup",
|
14 |
+
"intersperse-k",
|
15 |
+
"pow-base-k",
|
16 |
+
"prime",
|
17 |
+
"replace-all-k-with-n",
|
18 |
+
"replace-index-k-with-n",
|
19 |
+
"uniq",
|
20 |
+
}
|
21 |
+
|
22 |
+
# These are tasks that are easy (solved from base DSL) and also uninteresting
|
23 |
+
# We exclude these from the test set
|
24 |
+
EASYLISTTASKS = {
|
25 |
+
"add-k with k=2",
|
26 |
+
"bool-identify-geq-k with k=2",
|
27 |
+
"bool-identify-geq-k with k=3",
|
28 |
+
"bool-identify-is-mod-k with k=1",
|
29 |
+
"bool-identify-is-prime",
|
30 |
+
"bool-identify-k with k=0",
|
31 |
+
"bool-identify-k with k=1",
|
32 |
+
"bool-identify-k with k=2",
|
33 |
+
"caesar-cipher-k-modulo-n with k=3 and n=2",
|
34 |
+
"drop-k with k=1",
|
35 |
+
"drop-k with k=2",
|
36 |
+
"drop-k with k=4",
|
37 |
+
"index-head",
|
38 |
+
"index-k with k=2",
|
39 |
+
"index-k with k=4",
|
40 |
+
"is-mod-k with k=1",
|
41 |
+
"is-odds",
|
42 |
+
"is-squares",
|
43 |
+
"pow-k with k=2",
|
44 |
+
"pow-k with k=3",
|
45 |
+
"prepend-index-k with k=3",
|
46 |
+
"prepend-index-k with k=5",
|
47 |
+
"prepend-k with k=1",
|
48 |
+
"prepend-k with k=2",
|
49 |
+
"prepend-k with k=3",
|
50 |
+
"remove-index-k with k=1",
|
51 |
+
"replace-all-with-index-k with k=2",
|
52 |
+
"replace-all-with-index-k with k=3",
|
53 |
+
"slice-k-n with k=1 and n=2",
|
54 |
+
"slice-k-n with k=2 and n=1",
|
55 |
+
"slice-k-n with k=3 and n=1",
|
56 |
+
}
|
57 |
+
|
58 |
+
def make_list_task(name, examples, **params):
|
59 |
+
input_type = guess_type([i for (i,), _ in examples])
|
60 |
+
output_type = guess_type([o for _, o in examples])
|
61 |
+
|
62 |
+
# We can internally handle lists of bools.
|
63 |
+
# We explicitly create these by modifying existing routines.
|
64 |
+
if name.startswith("identify"):
|
65 |
+
boolexamples = [((i,), list(map(bool, o))) for (i,), o in examples]
|
66 |
+
yield from make_list_task("bool-" + name, boolexamples, **params)
|
67 |
+
# for now, we'll stick with the boolean-only tasks and not have a copy
|
68 |
+
# for integers.
|
69 |
+
return
|
70 |
+
|
71 |
+
program_type = arrow(input_type, output_type)
|
72 |
+
cache = all(hashable(x) for x in examples)
|
73 |
+
|
74 |
+
if params:
|
75 |
+
eq_params = ["{}={}".format(k, v) for k, v in params.items()]
|
76 |
+
if len(eq_params) == 1:
|
77 |
+
ext = eq_params[0]
|
78 |
+
elif len(eq_params) == 2:
|
79 |
+
ext = "{} and {}".format(*eq_params)
|
80 |
+
else:
|
81 |
+
ext = ", ".join(eq_params[:-1])
|
82 |
+
ext = "{}, and {}".format(ext, eq_params[-1])
|
83 |
+
name += " with " + ext
|
84 |
+
|
85 |
+
yield Task(name, program_type, examples, cache=cache)
|
86 |
+
|
87 |
+
|
88 |
+
def make_list_tasks(n_examples):
|
89 |
+
import listroutines as lr
|
90 |
+
|
91 |
+
for routine in lr.find(count=100): # all routines
|
92 |
+
if routine.id in EXCLUDES:
|
93 |
+
continue
|
94 |
+
if routine.is_parametric():
|
95 |
+
keys = list(routine.example_params()[0].keys())
|
96 |
+
for params in map(lambda values: dict(zip(keys, values)),
|
97 |
+
product(range(6), repeat=len(keys))):
|
98 |
+
try:
|
99 |
+
if routine.id == "rotate-k":
|
100 |
+
# rotate-k is hard if list is smaller than k
|
101 |
+
k = params["k"]
|
102 |
+
if k < 1:
|
103 |
+
continue
|
104 |
+
inps = []
|
105 |
+
for _ in range(n_examples):
|
106 |
+
r = randint(abs(k) + 1, 17)
|
107 |
+
inp = routine.gen(len=r, **params)[0]
|
108 |
+
inps.append(inp)
|
109 |
+
else:
|
110 |
+
inps = routine.gen(count=n_examples, **params)
|
111 |
+
examples = [((inp,), routine.eval(inp, **params))
|
112 |
+
for inp in inps]
|
113 |
+
yield from make_list_task(routine.id, examples, **params)
|
114 |
+
except lr.APIError: # invalid params
|
115 |
+
continue
|
116 |
+
else:
|
117 |
+
inps = routine.examples()
|
118 |
+
if len(inps) > n_examples:
|
119 |
+
inps = inps[:n_examples]
|
120 |
+
elif len(inps) < n_examples:
|
121 |
+
inps += routine.gen(count=(n_examples - len(inps)))
|
122 |
+
examples = [((inp,), routine.eval(inp)) for inp in inps]
|
123 |
+
yield from make_list_task(routine.id, examples)
|
124 |
+
|
125 |
+
|
126 |
+
def make_list_bootstrap_tasks():
|
127 |
+
seed(42)
|
128 |
+
|
129 |
+
def suffixes(l):
|
130 |
+
if l == []:
|
131 |
+
return []
|
132 |
+
else:
|
133 |
+
return [l[1:]] + suffixes(l[1:])
|
134 |
+
|
135 |
+
def flip(): return random() > 0.5
|
136 |
+
|
137 |
+
def randomSuffix():
|
138 |
+
return [randint(0, 9) for _ in range(randint(1, 4))]
|
139 |
+
|
140 |
+
def randomList(minimum=0, minimumLength=4, maximumLength=6):
|
141 |
+
return [randint(minimum, 9) for _ in range(randint(minimumLength, maximumLength))]
|
142 |
+
|
143 |
+
def randomListOfLists():
|
144 |
+
return [randomSuffix() for _ in range(randint(2, 4))]
|
145 |
+
|
146 |
+
def randomListOfLists_bool(l=None):
|
147 |
+
if l is None:
|
148 |
+
l = randint(4, 7)
|
149 |
+
return [randomBooleanList() for _ in range(l)]
|
150 |
+
|
151 |
+
def randomBooleanList():
|
152 |
+
return [flip() for _ in range(randint(4, 7))]
|
153 |
+
|
154 |
+
# Reliably learned in under a minute; always triggers learning of length
|
155 |
+
# primitive
|
156 |
+
lengthBootstrap = [
|
157 |
+
# Task("length bool", arrow(tlist(tbool), tint),
|
158 |
+
# [((l,), len(l))
|
159 |
+
# for _ in range(10)
|
160 |
+
# for l in [[flip() for _ in range(randint(0, 10))]]]),
|
161 |
+
Task("length int", arrow(tlist(tint), tint),
|
162 |
+
[((l,), len(l))
|
163 |
+
for _ in range(10)
|
164 |
+
for l in [randomList()]]),
|
165 |
+
Task("map length", arrow(tlist(tlist(tint)), tlist(tint)),
|
166 |
+
[((xss,), [len(xs) for xs in xss])
|
167 |
+
for _ in range(10)
|
168 |
+
for xss in [randomListOfLists()] ])
|
169 |
+
]
|
170 |
+
|
171 |
+
# Encourages learning of unfolding
|
172 |
+
unfoldBootstrap = [
|
173 |
+
Task("countdown", arrow(tint, tlist(tint)),
|
174 |
+
[((n,), list(range(n + 1, 1, -1)))
|
175 |
+
for n in range(10)]),
|
176 |
+
Task("weird count", arrow(tint, tlist(tint)),
|
177 |
+
[((n,), list(range(-n,0,-1)))
|
178 |
+
for n in range(-10,0) ]),
|
179 |
+
Task("take every other", arrow(tlist(tint),tlist(tint)),
|
180 |
+
[((l,), [x for j,x in enumerate(l) if j%2 == 0])
|
181 |
+
for _ in range(9)
|
182 |
+
for l in [ [randint(0, 9) for _ in range(randint(1,4)*2)] ] ] + [(([],),[])]),
|
183 |
+
# Task("stutter every other", arrow(tlist(tint),tlist(tint)),
|
184 |
+
# [((l,), [l[int(j/2)] for j in range(len(l)) ])
|
185 |
+
# for _ in range(10)
|
186 |
+
# for l in [ [randint(0, 9) for _ in range(randint(1,4)*2)] ] ]),
|
187 |
+
# Task("take until 3 reached", arrow(tlist(tint),tlist(tint)),
|
188 |
+
# [((p + [3] + s,),p)
|
189 |
+
# for _ in range(10)
|
190 |
+
# for p in [ [z for z in randomList()[:5] if z != 3 ]]
|
191 |
+
# for s in [randomList()] ]),
|
192 |
+
Task("drop last element", arrow(tlist(tint),tlist(tint)),
|
193 |
+
[((l,), l[:-1])
|
194 |
+
for _ in range(10)
|
195 |
+
for l in [ [randint(0, 9) for _ in range(randint(2,5))] ] ]),
|
196 |
+
# Task("suffixes", arrow(tlist(tint), tlist(tlist(tint))),
|
197 |
+
# [((l,), suffixes(l))
|
198 |
+
# for _ in range(10)
|
199 |
+
# for l in [randomList()]]),
|
200 |
+
Task("range", arrow(tint, tlist(tint)),
|
201 |
+
[((n,), list(range(n)))
|
202 |
+
for n in range(10)]),
|
203 |
+
Task("range inclusive", arrow(tint, tlist(tint)),
|
204 |
+
[((n,), list(range(n + 1)))
|
205 |
+
for n in range(10)]),
|
206 |
+
# Task("range inclusive+1", arrow(tint, tlist(tint)),
|
207 |
+
# [((n,), list(range(n + 2)))
|
208 |
+
# for n in range(10)]),
|
209 |
+
# Task("range exclusive", arrow(tint, tlist(tint)),
|
210 |
+
# [((n,), list(range(n - 1)))
|
211 |
+
# for n in range(2, 11)]),
|
212 |
+
# Task("range length", arrow(tlist(tint),tlist(tint)),
|
213 |
+
# [((l,),list(range(len(l))))
|
214 |
+
# for _ in range(10)
|
215 |
+
# for l in [randomList()] ])
|
216 |
+
]
|
217 |
+
|
218 |
+
# Encourages learning how to treat a list as an array
|
219 |
+
arrayBootstrap = [
|
220 |
+
Task("index int", arrow(tint, tlist(tint), tint),
|
221 |
+
[((n, l), l[n])
|
222 |
+
for n in range(10)
|
223 |
+
for l in [[randint(0, 9) for _ in range(randint(n + 1, n + 5))]]]),
|
224 |
+
# Task("last n", arrow(tint, tlist(tint), tlist(tint)),
|
225 |
+
# [((n, l), l[-n:])
|
226 |
+
# for n in range(10)
|
227 |
+
# for l in [[randint(0, 9) for _ in range(randint(n + 1, n + 5))]]]),
|
228 |
+
Task("1-index int", arrow(tint, tlist(tint), tint),
|
229 |
+
[((n, l), l[n - 1])
|
230 |
+
for n in range(1,11)
|
231 |
+
for l in [[randint(0, 9) for _ in range(randint(n + 1, n + 4))]]])
|
232 |
+
|
233 |
+
# Task("index bool", arrow(tint, tlist(tbool), tbool),
|
234 |
+
# [((n, l), l[n])
|
235 |
+
# for n in range(10)
|
236 |
+
# for l in [[flip() for _ in range(randint(n + 1, n + 5))]]])
|
237 |
+
]
|
238 |
+
|
239 |
+
# Teaches how to slice lists, not sure if we really need this though
|
240 |
+
sliceBootstrap = [
|
241 |
+
Task("take bool", arrow(tint, tlist(tbool), tlist(tbool)),
|
242 |
+
[((n, l), l[:n])
|
243 |
+
for n in range(10)
|
244 |
+
for l in [[flip() for _ in range(randint(n, n + 5))]]]),
|
245 |
+
Task("drop bool", arrow(tint, tlist(tbool), tlist(tbool)),
|
246 |
+
[((n, l), l[n:])
|
247 |
+
for n in range(10)
|
248 |
+
for l in [[flip() for _ in range(randint(n, n + 5))]]]),
|
249 |
+
|
250 |
+
Task("take int", arrow(tint, tlist(tint), tlist(tint)),
|
251 |
+
[((n, l), l[:n])
|
252 |
+
for n in range(10)
|
253 |
+
for l in [[randint(0, 9) for _ in range(randint(n, n + 5))]]]),
|
254 |
+
Task("drop int", arrow(tint, tlist(tint), tlist(tint)),
|
255 |
+
[((n, l), l[n:])
|
256 |
+
for n in range(10)
|
257 |
+
for l in [[randint(0, 9) for _ in range(randint(n, n + 5))]]]),
|
258 |
+
|
259 |
+
]
|
260 |
+
|
261 |
+
# learning to fold
|
262 |
+
foldBootstrap = [
|
263 |
+
Task("stutter", arrow(tlist(tint),tlist(tint)),
|
264 |
+
[((l,), [z for x in l for z in [x,x] ])
|
265 |
+
for _ in range(10)
|
266 |
+
for l in [randomList()] ]),
|
267 |
+
Task("sum", arrow(tlist(tint), tint),
|
268 |
+
[((l,), sum(l))
|
269 |
+
for _ in range(10)
|
270 |
+
for l in [randomList()]]),
|
271 |
+
# Task("difference", arrow(tlist(tint), tint),
|
272 |
+
# [((l,), reduce(lambda x, y: y - x, reversed(l), 1))
|
273 |
+
# for _ in range(10)
|
274 |
+
# for l in [randomList()[:4]]]),
|
275 |
+
# Task("append bool", arrow(tlist(tbool), tlist(tbool), tlist(tbool)),
|
276 |
+
# [((x, y), x + y)
|
277 |
+
# for _ in range(10)
|
278 |
+
# for [x, y] in [[randomBooleanList(), randomBooleanList()]]]),
|
279 |
+
Task("append constant 0", arrow(tlist(tint),tlist(tint)),
|
280 |
+
[((l,),l + [0])
|
281 |
+
for _ in range(10)
|
282 |
+
for l in [randomList()] ]),
|
283 |
+
]
|
284 |
+
|
285 |
+
# learning to map
|
286 |
+
mapBootstrap = [
|
287 |
+
Task("map double", arrow(tlist(tint), tlist(tint)),
|
288 |
+
[((l,), list(map(lambda n: n * 2, l)))
|
289 |
+
for _ in range(10)
|
290 |
+
for l in [randomList()]]),
|
291 |
+
Task("map increment", arrow(tlist(tint),tlist(tint)),
|
292 |
+
[((l,),list(map(lambda n: n+1, l)))
|
293 |
+
for _ in range(10)
|
294 |
+
for l in [randomList()] ]),
|
295 |
+
Task("map negation", arrow(tlist(tint),tlist(tint)),
|
296 |
+
[((l,),list(map(lambda n: 0-n, l)))
|
297 |
+
for _ in range(10)
|
298 |
+
for l in [randomList()] ]),
|
299 |
+
# Task("map car", arrow(tlist(tlist(tint)), tlist(tint)),
|
300 |
+
# [((l,), [n[0] for n in l])
|
301 |
+
# for _ in4 range(10)
|
302 |
+
# for l in [randomListOfLists()]]),
|
303 |
+
# Task("map cdr", arrow(tlist(tlist(tbool)),tlist(tlist(tbool))),
|
304 |
+
# [((l,),map(lambda n: n[1:],l))
|
305 |
+
# for _ in range(10)
|
306 |
+
# for l in [randomListOfLists_bool()]]),
|
307 |
+
# Task("map empty?", arrow(tlist(tlist(tint)), tlist(tboolean)),
|
308 |
+
# [((l,), [n == [] for n in l])
|
309 |
+
# for _ in range(10)
|
310 |
+
# for l in [[[] if flip() else randomList() for _ in range(randint(1, 5))]]]),
|
311 |
+
|
312 |
+
# Task("map eq 0?", arrow(tlist(tint),tlist(tboolean)),
|
313 |
+
# [((l,),map(lambda n: 0 == n,l))
|
314 |
+
# for _ in range(10)
|
315 |
+
# for l in [[ randint(0,3) for _ in range(randint(4,7)) ]] ])
|
316 |
+
|
317 |
+
]
|
318 |
+
difficultMaps = [
|
319 |
+
Task("map quadruple", arrow(tlist(tint), tlist(tint)),
|
320 |
+
[((l,), list(map(lambda n: n * 4, l)))
|
321 |
+
for _ in range(10)
|
322 |
+
for l in [randomList()]]),
|
323 |
+
Task("map add 3", arrow(tlist(tint),tlist(tint)),
|
324 |
+
[((l,),list(map(lambda n: n+3, l)))
|
325 |
+
for _ in range(10)
|
326 |
+
for l in [randomList()] ]),
|
327 |
+
|
328 |
+
]
|
329 |
+
|
330 |
+
# Learning to zip lists together
|
331 |
+
zipBootstrap = [
|
332 |
+
Task("zip plus", arrow(tlist(tint),tlist(tint),tlist(tint)),
|
333 |
+
[((l1,l2),list(map(lambda x,y: x+y,l1,l2)))
|
334 |
+
for _ in range(10)
|
335 |
+
for l1 in [randomList(minimumLength=2, maximumLength=4)]
|
336 |
+
for l2 in [[ randint(0,9) for _ in range(len(l1)) ]]]),
|
337 |
+
Task("zip minus", arrow(tlist(tint),tlist(tint),tlist(tint)),
|
338 |
+
[((l1,l2),list(map(lambda x,y: x-y,l1,l2)))
|
339 |
+
for _ in range(10)
|
340 |
+
for l1 in [randomList(minimumLength=2, maximumLength=4)]
|
341 |
+
for l2 in [[ randint(0,9) for _ in range(len(l1)) ]]]),
|
342 |
+
# Task("zip eq?", arrow(tlist(tint), tlist(tint), tlist(tbool)),
|
343 |
+
# [((l1, l2), list(map(lambda x, y: x == y, l1, l2)))
|
344 |
+
# for _ in range(10)
|
345 |
+
# for l1 in [[randint(0, 3) for _ in range(randint(4, 7))]]
|
346 |
+
# for l2 in [[randint(0, 3) for _ in range(len(l1))]]]),
|
347 |
+
# Task("zip cons", arrow(tlist(tbool), tlist(tlist(tbool)), tlist(tlist(tbool))),
|
348 |
+
# [((l1, l2), list(map(lambda x, y: [x] + y, l1, l2)))
|
349 |
+
# for _ in range(10)
|
350 |
+
# for l1 in [randomBooleanList()]
|
351 |
+
# for l2 in [randomListOfLists_bool(l=len(l1))]]),
|
352 |
+
# Task("zip cons", arrow(tlist(tint),tlist(tlist(tint)),tlist(tlist(tint))),
|
353 |
+
# [((l1,l2),list(map(lambda x,y: [x]+y,l1,l2)))
|
354 |
+
# for _ in range(10)
|
355 |
+
# for l1 in [randomList()]
|
356 |
+
# for l2 in [[ randomList() for _ in range(len(l1)) ]]]),
|
357 |
+
]
|
358 |
+
|
359 |
+
# Learning to filter
|
360 |
+
filterBootstrap = [
|
361 |
+
# Task("remove empty lists",
|
362 |
+
# arrow(tlist(tlist(tbool)), tlist(tlist(tbool))),
|
363 |
+
# [((ls,), [l for l in ls if len(l) > 0])
|
364 |
+
# for _ in range(10)
|
365 |
+
# for ls in [[[flip() for _ in range(randint(0, 3))]
|
366 |
+
# for _ in range(4)]]])
|
367 |
+
# Task("remove non 0s",
|
368 |
+
# arrow(tlist(tint), tlist(tint)),
|
369 |
+
# [((xs,), filter(lambda x: x == 0, xs))
|
370 |
+
# for _ in range(10)
|
371 |
+
# for xs in [[ randint(0,3) for _ in range(5) ]] ]),
|
372 |
+
Task("remove 0s",
|
373 |
+
arrow(tlist(tint), tlist(tint)),
|
374 |
+
[((xs,), [x for x in xs if x != 0])
|
375 |
+
for _ in range(10)
|
376 |
+
for xs in [[randint(0, 3) for _ in range(5)]]]),
|
377 |
+
Task("remove non-positives",
|
378 |
+
arrow(tlist(tint), tlist(tint)),
|
379 |
+
[((xs,), [x for x in xs if not (x > 1)])
|
380 |
+
for _ in range(10)
|
381 |
+
for xs in [[randint(0, 3) for _ in range(5)]]]),
|
382 |
+
]
|
383 |
+
|
384 |
+
return lengthBootstrap + filterBootstrap + \
|
385 |
+
unfoldBootstrap + arrayBootstrap + foldBootstrap + mapBootstrap + zipBootstrap
|
386 |
+
|
387 |
+
|
388 |
+
def bonusListProblems():
|
389 |
+
# Taken from https://www.ijcai.org/Proceedings/75/Papers/037.pdf
|
390 |
+
# These problems might be a lot easier if we do not use numbers
|
391 |
+
def randomList(lb=None, ub=None):
|
392 |
+
if lb is None:
|
393 |
+
lb = 2
|
394 |
+
if ub is None:
|
395 |
+
ub = 5
|
396 |
+
return [randint(0, 5) for _ in range(randint(lb, ub))]
|
397 |
+
|
398 |
+
bonus = [
|
399 |
+
Task(
|
400 |
+
"pair reverse", arrow(tlist(tint), tlist(tint)),
|
401 |
+
[((x,), [x[j + (1 if j % 2 == 0 else -1)]
|
402 |
+
for j in range(len(x))])
|
403 |
+
for _ in range(5)
|
404 |
+
for x in [randomList(10, 10)]]
|
405 |
+
),
|
406 |
+
Task(
|
407 |
+
"duplicate each element", arrow(tlist(tint), tlist(tint)),
|
408 |
+
[((x,), [a for z in x for a in [z] * 2])
|
409 |
+
for _ in range(5)
|
410 |
+
for x in [randomList(4, 6)]]
|
411 |
+
),
|
412 |
+
Task(
|
413 |
+
"reverse duplicate each element", arrow(tlist(tint), tlist(tint)),
|
414 |
+
[((x,), [a for z in reversed(x) for a in [z] * 2])]
|
415 |
+
),
|
416 |
+
]
|
417 |
+
return bonus
|
418 |
+
|
419 |
+
def sortBootstrap():
|
420 |
+
# These tasks have as their goal the learning of (1) filter, and
|
421 |
+
# (2) sort, which uses filter.
|
422 |
+
def flip(): return random() > 0.5
|
423 |
+
def randomList(lb=None, ub=None):
|
424 |
+
if lb is None:
|
425 |
+
lb = 2
|
426 |
+
if ub is None:
|
427 |
+
ub = 5
|
428 |
+
return [randint(0, 10) for _ in range(randint(lb, ub))]
|
429 |
+
def randomBooleanList():
|
430 |
+
return [flip() for _ in range(randint(4, 7))]
|
431 |
+
def removeDuplicates(l):
|
432 |
+
if len(l) == 0: return l
|
433 |
+
return [l[0]] + removeDuplicates([ z for z in l if z != l[0] ])
|
434 |
+
|
435 |
+
filterBootstrap = [
|
436 |
+
# Task("remove empty lists",
|
437 |
+
# arrow(tlist(tlist(tbool)), tlist(tlist(tbool))),
|
438 |
+
# [((ls,), [l for l in ls if len(l) > 0])
|
439 |
+
# for _ in range(10)
|
440 |
+
# for ls in [[[flip() for _ in range(randint(0, 3))]
|
441 |
+
# for _ in range(4)]]]),
|
442 |
+
# Task("remove non 0s",
|
443 |
+
# arrow(tlist(tint), tlist(tint)),
|
444 |
+
# [((xs,), filter(lambda x: x == 0, xs))
|
445 |
+
# for _ in range(10)
|
446 |
+
# for xs in [[ randint(0,3) for _ in range(5) ]] ]),
|
447 |
+
Task("remove 0s",
|
448 |
+
arrow(tlist(tint), tlist(tint)),
|
449 |
+
[((xs,), [x for x in xs if x != 0])
|
450 |
+
for _ in range(10)
|
451 |
+
for xs in [[randint(0, 3) for _ in range(5)]]]),
|
452 |
+
# Task("remove primes",
|
453 |
+
# arrow(tlist(tint), tlist(tint)),
|
454 |
+
# [((xs,), [x for x in xs if not (x in {2,3,5,7,11,13,17,19,23})])
|
455 |
+
# for _ in range(10)
|
456 |
+
# for xs in [[randint(0, 20) for _ in range(7)]]]),
|
457 |
+
Task("remove squares",
|
458 |
+
arrow(tlist(tint), tlist(tint)),
|
459 |
+
[((xs,), [x for x in xs if not (int(x**0.5)**2 == x)])
|
460 |
+
for _ in range(10)
|
461 |
+
for xs in [[randint(0, 20) for _ in range(7)]]]),
|
462 |
+
Task("remove > 1",
|
463 |
+
arrow(tlist(tint), tlist(tint)),
|
464 |
+
[((xs,), [x for x in xs if not (x > 1)])
|
465 |
+
for _ in range(10)
|
466 |
+
for xs in [[randint(0, 5) for _ in range(7)]]]),
|
467 |
+
]
|
468 |
+
|
469 |
+
# Needed for selection sort
|
470 |
+
minimumBootstrap = [
|
471 |
+
Task("min2", arrow(tint,tint,tint),
|
472 |
+
[((x,y),min(x,y))
|
473 |
+
for x in range(4)
|
474 |
+
for y in range(4) ]),
|
475 |
+
Task("minimum of list", arrow(tlist(tint),tint),
|
476 |
+
[((l,),min(l))
|
477 |
+
for _ in range(15)
|
478 |
+
for l in [randomList()] ])
|
479 |
+
]
|
480 |
+
|
481 |
+
appendBootstrap = [
|
482 |
+
Task("append bool", arrow(tlist(tbool), tlist(tbool), tlist(tbool)),
|
483 |
+
[((x, y), x + y)
|
484 |
+
for _ in range(10)
|
485 |
+
for [x, y] in [[randomBooleanList(), randomBooleanList()]]]),
|
486 |
+
Task("append int", arrow(tlist(tint), tlist(tint), tlist(tint)),
|
487 |
+
[((x, y), x + y)
|
488 |
+
for _ in range(10)
|
489 |
+
for [x, y] in [[randomList(), randomList()]]])
|
490 |
+
]
|
491 |
+
|
492 |
+
insertionBootstrap = [
|
493 |
+
Task("filter greater than or equal", arrow(tint,tlist(tint),tlist(tint)),
|
494 |
+
[((x,l), [y for y in l if y >= x ])
|
495 |
+
for _ in range(15)
|
496 |
+
for x in [randint(0,5)]
|
497 |
+
for l in [randomList()] ]),
|
498 |
+
Task("filter less than", arrow(tint,tlist(tint),tlist(tint)),
|
499 |
+
[((x,l), [y for y in l if y < x ])
|
500 |
+
for _ in range(15)
|
501 |
+
for x in [randint(0,5)]
|
502 |
+
for l in [randomList()] ]),
|
503 |
+
Task("insert into sorted list (I)", arrow(tint,tlist(tint),tlist(tint)),
|
504 |
+
[((x,l), [y for y in l if y < x ] + [x] + [y for y in l if y >= x ])
|
505 |
+
for _ in range(15)
|
506 |
+
for x in [randint(0,5)]
|
507 |
+
for _l in [randomList()]
|
508 |
+
for l in [sorted(_l)] ]),
|
509 |
+
Task("insert into sorted list (II)", arrow(tint,tlist(tint),tlist(tint)),
|
510 |
+
[((x,l), [y for y in l if y < x ] + [x] + [y for y in l if y >= x ])
|
511 |
+
for _ in range(15)
|
512 |
+
for x in [randint(0,5)]
|
513 |
+
for l in [randomList()] ])
|
514 |
+
]
|
515 |
+
|
516 |
+
|
517 |
+
sortTask = [
|
518 |
+
Task("sort-and-deduplicate", arrow(tlist(tint),tlist(tint)),
|
519 |
+
[((l,),list(sorted(l)))
|
520 |
+
for _ in range(15)
|
521 |
+
for l in [removeDuplicates(randomList())]
|
522 |
+
])]
|
523 |
+
|
524 |
+
slowSort = [
|
525 |
+
Task("+1 maximum list", arrow(tlist(tint), tint),
|
526 |
+
[((l,),max(l) + 1)
|
527 |
+
for _ in range(15)
|
528 |
+
for l in [randomList()] ]),
|
529 |
+
Task("range +1 maximum list", arrow(tlist(tint), tlist(tint)),
|
530 |
+
[((l,),list(range(max(l) + 1)))
|
531 |
+
for _ in range(15)
|
532 |
+
for l in [randomList()] ]),
|
533 |
+
]
|
534 |
+
|
535 |
+
|
536 |
+
tasks = sortTask + slowSort
|
537 |
+
for t in tasks: t.mustTrain = True
|
538 |
+
return tasks
|
539 |
+
|
540 |
+
|
541 |
+
def exportTasks():
|
542 |
+
import sys
|
543 |
+
import pickle as pickle
|
544 |
+
|
545 |
+
n_examples = 15
|
546 |
+
if len(sys.argv) > 1:
|
547 |
+
n_examples = int(sys.argv[1])
|
548 |
+
|
549 |
+
eprint("Downloading and generating dataset")
|
550 |
+
tasks = sorted(make_list_tasks(n_examples), key=lambda t: t.name)
|
551 |
+
eprint("Got {} list tasks".format(len(tasks)))
|
552 |
+
|
553 |
+
with open("data/list_tasks.pkl", "w") as f:
|
554 |
+
pickle.dump(tasks, f)
|
555 |
+
eprint("Wrote list tasks to data/list_tasks.pkl")
|
556 |
+
|
557 |
+
|
558 |
+
if __name__ == "__main__":
|
559 |
+
import json
|
560 |
+
def retrieveJSONTasks(filename, features=False):
|
561 |
+
"""
|
562 |
+
For JSON of the form:
|
563 |
+
{"name": str,
|
564 |
+
"type": {"input" : bool|int|list-of-bool|list-of-int,
|
565 |
+
"output": bool|int|list-of-bool|list-of-int},
|
566 |
+
"examples": [{"i": data, "o": data}]}
|
567 |
+
"""
|
568 |
+
with open(filename, "r") as f:
|
569 |
+
loaded = json.load(f)
|
570 |
+
TP = {
|
571 |
+
"bool": tbool,
|
572 |
+
"int": tint,
|
573 |
+
"list-of-bool": tlist(tbool),
|
574 |
+
"list-of-int": tlist(tint),
|
575 |
+
}
|
576 |
+
return [Task(
|
577 |
+
item["name"],
|
578 |
+
arrow(TP[item["type"]["input"]], TP[item["type"]["output"]]),
|
579 |
+
[((ex["i"],), ex["o"]) for ex in item["examples"]],
|
580 |
+
features=(None if not features else list_features(
|
581 |
+
[((ex["i"],), ex["o"]) for ex in item["examples"]])),
|
582 |
+
cache=False,
|
583 |
+
) for item in loaded]
|
584 |
+
for t in retrieveJSONTasks("data/list_tasks.json") + sortBootstrap() + make_list_bootstrap_tasks():
|
585 |
+
print(t.describe())
|
586 |
+
print()
|
587 |
+
# exportTasks()
|
dreamcoder/domains/logo/__init__.py
ADDED
File without changes
|
dreamcoder/domains/logo/logoPrimitives.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dreamcoder.program import Primitive, Program
|
2 |
+
from dreamcoder.type import arrow, baseType, tint
|
3 |
+
|
4 |
+
turtle = baseType("turtle")
|
5 |
+
tstate = baseType("tstate")
|
6 |
+
tangle = baseType("tangle")
|
7 |
+
tlength = baseType("tlength")
|
8 |
+
|
9 |
+
primitives = [
|
10 |
+
Primitive("logo_UA", tangle, ""),
|
11 |
+
Primitive("logo_UL", tlength, ""),
|
12 |
+
|
13 |
+
Primitive("logo_ZA", tangle, ""),
|
14 |
+
Primitive("logo_ZL", tlength, ""),
|
15 |
+
|
16 |
+
Primitive("logo_DIVA", arrow(tangle,tint,tangle), ""),
|
17 |
+
Primitive("logo_MULA", arrow(tangle,tint,tangle), ""),
|
18 |
+
Primitive("logo_DIVL", arrow(tlength,tint,tlength), ""),
|
19 |
+
Primitive("logo_MULL", arrow(tlength,tint,tlength), ""),
|
20 |
+
|
21 |
+
Primitive("logo_ADDA", arrow(tangle,tangle,tangle), ""),
|
22 |
+
Primitive("logo_SUBA", arrow(tangle,tangle,tangle), ""),
|
23 |
+
# Primitive("logo_ADDL", arrow(tlength,tlength,tlength), ""),
|
24 |
+
# Primitive("logo_SUBL", arrow(tlength,tlength,tlength), ""),
|
25 |
+
|
26 |
+
# Primitive("logo_PU", arrow(turtle,turtle), ""),
|
27 |
+
# Primitive("logo_PD", arrow(turtle,turtle), ""),
|
28 |
+
Primitive("logo_PT", arrow(arrow(turtle,turtle),arrow(turtle,turtle)), None),
|
29 |
+
Primitive("logo_FWRT", arrow(tlength,tangle,turtle,turtle), ""),
|
30 |
+
Primitive("logo_GETSET", arrow(arrow(turtle,turtle),turtle,turtle), "")
|
31 |
+
] + [
|
32 |
+
Primitive("logo_IFTY", tint, ""),
|
33 |
+
Primitive("logo_epsA", tangle, ""),
|
34 |
+
Primitive("logo_epsL", tlength, ""),
|
35 |
+
Primitive("logo_forLoop", arrow(tint, arrow(tint, turtle, turtle), turtle, turtle), "ERROR: python has no way of expressing this hence you shouldn't eval on this"),
|
36 |
+
] + [Primitive(str(j), tint, j) for j in range(10)]
|
37 |
+
|
38 |
+
if __name__ == "__main__":
|
39 |
+
expr_s = "(lambda (logo_forLoop 3 (lambda (lambda (logo_GET (lambda (logo_FWRT (logo_S2L (logo_I2S 1)) (logo_S2A (logo_I2S 0)) (logo_SET $0 (logo_FWRT (logo_S2L eps) (logo_DIVA (logo_S2A (logo_I2S 2)) (logo_I2S 3)) ($1)))))))) ($0)))"
|
40 |
+
x = Program.parse(expr_s)
|
41 |
+
print(x)
|
dreamcoder/domains/logo/main.py
ADDED
@@ -0,0 +1,450 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import OrderedDict
|
2 |
+
import datetime
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import pickle
|
6 |
+
import random as random
|
7 |
+
import subprocess
|
8 |
+
import sys
|
9 |
+
import time
|
10 |
+
|
11 |
+
try:
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import torch.nn as nn
|
15 |
+
import torch.nn.functional as F
|
16 |
+
except:
|
17 |
+
print("WARNING: Could not import torch. This is only okay when doing pypy compression.",
|
18 |
+
file=sys.stderr)
|
19 |
+
|
20 |
+
from dreamcoder.domains.logo.makeLogoTasks import makeTasks, montageTasks, drawLogo
|
21 |
+
from dreamcoder.domains.logo.logoPrimitives import primitives, turtle, tangle, tlength
|
22 |
+
from dreamcoder.dreamcoder import ecIterator
|
23 |
+
from dreamcoder.grammar import Grammar
|
24 |
+
from dreamcoder.program import Program
|
25 |
+
try:
|
26 |
+
from dreamcoder.recognition import variable, maybe_cuda
|
27 |
+
except:
|
28 |
+
print("WARNING: Could not import recognition. This is only okay when doing pypy compression.",
|
29 |
+
file=sys.stderr)
|
30 |
+
from dreamcoder.task import Task
|
31 |
+
from dreamcoder.type import arrow
|
32 |
+
from dreamcoder.utilities import eprint, testTrainSplit, loadPickle
|
33 |
+
|
34 |
+
|
35 |
+
def animateSolutions(allFrontiers):
|
36 |
+
programs = []
|
37 |
+
filenames = []
|
38 |
+
for n,(t,f) in enumerate(allFrontiers.items()):
|
39 |
+
if f.empty: continue
|
40 |
+
|
41 |
+
programs.append(f.bestPosterior.program)
|
42 |
+
filenames.append(f"/tmp/logo_animation_{n}")
|
43 |
+
|
44 |
+
drawLogo(*programs, pretty=True, smoothPretty=True, resolution=128, animate=True,
|
45 |
+
filenames=filenames)
|
46 |
+
|
47 |
+
|
48 |
+
|
49 |
+
def dreamFromGrammar(g, directory, N=100):
|
50 |
+
if isinstance(g,Grammar):
|
51 |
+
programs = [ p
|
52 |
+
for _ in range(N)
|
53 |
+
for p in [g.sample(arrow(turtle,turtle),
|
54 |
+
maximumDepth=20)]
|
55 |
+
if p is not None]
|
56 |
+
else:
|
57 |
+
programs = g
|
58 |
+
drawLogo(*programs,
|
59 |
+
pretty=False, smoothPretty=False,
|
60 |
+
resolution=512,
|
61 |
+
filenames=[f"{directory}/{n}.png" for n in range(len(programs)) ],
|
62 |
+
timeout=1)
|
63 |
+
drawLogo(*programs,
|
64 |
+
pretty=True, smoothPretty=False,
|
65 |
+
resolution=512,
|
66 |
+
filenames=[f"{directory}/{n}_pretty.png" for n in range(len(programs)) ],
|
67 |
+
timeout=1)
|
68 |
+
drawLogo(*programs,
|
69 |
+
pretty=False, smoothPretty=True,
|
70 |
+
resolution=512,
|
71 |
+
filenames=[f"{directory}/{n}_smooth_pretty.png" for n in range(len(programs)) ],
|
72 |
+
timeout=1)
|
73 |
+
for n,p in enumerate(programs):
|
74 |
+
with open(f"{directory}/{n}.dream","w") as handle:
|
75 |
+
handle.write(str(p))
|
76 |
+
|
77 |
+
try:
|
78 |
+
class Flatten(nn.Module):
|
79 |
+
def __init__(self):
|
80 |
+
super(Flatten, self).__init__()
|
81 |
+
|
82 |
+
def forward(self, x):
|
83 |
+
return x.view(x.size(0), -1)
|
84 |
+
|
85 |
+
|
86 |
+
class LogoFeatureCNN(nn.Module):
|
87 |
+
special = "LOGO"
|
88 |
+
|
89 |
+
def __init__(self, tasks, testingTasks=[], cuda=False, H=64):
|
90 |
+
super(LogoFeatureCNN, self).__init__()
|
91 |
+
|
92 |
+
self.sub = prefix_dreams + str(int(time.time()))
|
93 |
+
|
94 |
+
self.recomputeTasks = False
|
95 |
+
|
96 |
+
def conv_block(in_channels, out_channels, p=True):
|
97 |
+
return nn.Sequential(
|
98 |
+
nn.Conv2d(in_channels, out_channels, 3, padding=1),
|
99 |
+
# nn.BatchNorm2d(out_channels),
|
100 |
+
nn.ReLU(),
|
101 |
+
# nn.Conv2d(out_channels, out_channels, 3, padding=1),
|
102 |
+
# nn.ReLU(),
|
103 |
+
nn.MaxPool2d(2))
|
104 |
+
|
105 |
+
self.inputImageDimension = 128
|
106 |
+
self.resizedDimension = 128
|
107 |
+
assert self.inputImageDimension % self.resizedDimension == 0
|
108 |
+
|
109 |
+
# channels for hidden
|
110 |
+
hid_dim = 64
|
111 |
+
z_dim = 64
|
112 |
+
|
113 |
+
self.encoder = nn.Sequential(
|
114 |
+
conv_block(1, hid_dim),
|
115 |
+
conv_block(hid_dim, hid_dim),
|
116 |
+
conv_block(hid_dim, hid_dim),
|
117 |
+
conv_block(hid_dim, hid_dim),
|
118 |
+
conv_block(hid_dim, hid_dim),
|
119 |
+
conv_block(hid_dim, z_dim),
|
120 |
+
Flatten()
|
121 |
+
)
|
122 |
+
|
123 |
+
self.outputDimensionality = 256
|
124 |
+
|
125 |
+
|
126 |
+
|
127 |
+
|
128 |
+
def forward(self, v):
|
129 |
+
assert len(v) == self.inputImageDimension*self.inputImageDimension
|
130 |
+
floatOnlyTask = list(map(float, v))
|
131 |
+
reshaped = [floatOnlyTask[i:i + self.inputImageDimension]
|
132 |
+
for i in range(0, len(floatOnlyTask), self.inputImageDimension)]
|
133 |
+
v = variable(reshaped).float()
|
134 |
+
# insert channel and batch
|
135 |
+
v = torch.unsqueeze(v, 0)
|
136 |
+
v = torch.unsqueeze(v, 0)
|
137 |
+
v = maybe_cuda(v, next(self.parameters()).is_cuda)/256.
|
138 |
+
window = int(self.inputImageDimension/self.resizedDimension)
|
139 |
+
v = F.avg_pool2d(v, (window,window))
|
140 |
+
v = self.encoder(v)
|
141 |
+
return v.view(-1)
|
142 |
+
|
143 |
+
def featuresOfTask(self, t): # Take a task and returns [features]
|
144 |
+
return self(t.highresolution)
|
145 |
+
|
146 |
+
def tasksOfPrograms(self, ps, types):
|
147 |
+
images = drawLogo(*ps, resolution=128)
|
148 |
+
if len(ps) == 1: images = [images]
|
149 |
+
tasks = []
|
150 |
+
for i in images:
|
151 |
+
if isinstance(i, str): tasks.append(None)
|
152 |
+
else:
|
153 |
+
t = Task("Helm", arrow(turtle,turtle), [])
|
154 |
+
t.highresolution = i
|
155 |
+
tasks.append(t)
|
156 |
+
return tasks
|
157 |
+
|
158 |
+
def taskOfProgram(self, p, t):
|
159 |
+
return self.tasksOfPrograms([p], None)[0]
|
160 |
+
except:
|
161 |
+
pass
|
162 |
+
|
163 |
+
def list_options(parser):
|
164 |
+
parser.add_argument("--proto",
|
165 |
+
default=False,
|
166 |
+
action="store_true",
|
167 |
+
help="Should we use prototypical networks?")
|
168 |
+
parser.add_argument("--target", type=str,
|
169 |
+
default=[],
|
170 |
+
action='append',
|
171 |
+
help="Which tasks should this try to solve")
|
172 |
+
parser.add_argument("--reduce", type=str,
|
173 |
+
default=[],
|
174 |
+
action='append',
|
175 |
+
help="Which tasks should this try to solve")
|
176 |
+
parser.add_argument("--save", type=str,
|
177 |
+
default=None,
|
178 |
+
help="Filepath output the grammar if this is a child")
|
179 |
+
parser.add_argument("--prefix", type=str,
|
180 |
+
default="experimentOutputs/",
|
181 |
+
help="Filepath output the grammar if this is a child")
|
182 |
+
parser.add_argument("--dreamCheckpoint", type=str,
|
183 |
+
default=None,
|
184 |
+
help="File to load in order to get dreams")
|
185 |
+
parser.add_argument("--dreamDirectory", type=str,
|
186 |
+
default=None,
|
187 |
+
help="Directory in which to dream from --dreamCheckpoint")
|
188 |
+
parser.add_argument("--visualize",
|
189 |
+
default=None, type=str)
|
190 |
+
parser.add_argument("--cost", default=False, action='store_true',
|
191 |
+
help="Impose a smooth cost on using ink")
|
192 |
+
parser.add_argument("--split",
|
193 |
+
default=1., type=float)
|
194 |
+
parser.add_argument("--animate",
|
195 |
+
default=None, type=str)
|
196 |
+
|
197 |
+
|
198 |
+
|
199 |
+
def outputDreams(checkpoint, directory):
|
200 |
+
from dreamcoder.utilities import loadPickle
|
201 |
+
result = loadPickle(checkpoint)
|
202 |
+
eprint(" [+] Loaded checkpoint",checkpoint)
|
203 |
+
g = result.grammars[-1]
|
204 |
+
if directory is None:
|
205 |
+
randomStr = ''.join(random.choice('0123456789') for _ in range(10))
|
206 |
+
directory = "/tmp/" + randomStr
|
207 |
+
eprint(" Dreaming into",directory)
|
208 |
+
os.system("mkdir -p %s"%directory)
|
209 |
+
dreamFromGrammar(g, directory)
|
210 |
+
|
211 |
+
def enumerateDreams(checkpoint, directory):
|
212 |
+
from dreamcoder.dreaming import backgroundHelmholtzEnumeration
|
213 |
+
from dreamcoder.utilities import loadPickle
|
214 |
+
result = loadPickle(checkpoint)
|
215 |
+
eprint(" [+] Loaded checkpoint",checkpoint)
|
216 |
+
g = result.grammars[-1]
|
217 |
+
if directory is None: assert False, "please specify a directory"
|
218 |
+
eprint(" Dreaming into",directory)
|
219 |
+
os.system("mkdir -p %s"%directory)
|
220 |
+
frontiers = backgroundHelmholtzEnumeration(makeTasks(None,None), g, 100,
|
221 |
+
evaluationTimeout=0.01,
|
222 |
+
special=LogoFeatureCNN.special)()
|
223 |
+
print(f"{len(frontiers)} total frontiers.")
|
224 |
+
MDL = 0
|
225 |
+
def L(f):
|
226 |
+
return -list(f.entries)[0].logPrior
|
227 |
+
frontiers.sort(key=lambda f: -L(f))
|
228 |
+
while len(frontiers) > 0:
|
229 |
+
# get frontiers whose MDL is between [MDL,MDL + 1)
|
230 |
+
fs = []
|
231 |
+
while len(frontiers) > 0 and L(frontiers[-1]) < MDL + 1:
|
232 |
+
fs.append(frontiers.pop(len(frontiers) - 1))
|
233 |
+
if fs:
|
234 |
+
random.shuffle(fs)
|
235 |
+
print(f"{len(fs)} programs with MDL between [{MDL}, {MDL + 1})")
|
236 |
+
|
237 |
+
fs = fs[:500]
|
238 |
+
os.system(f"mkdir {directory}/{MDL}")
|
239 |
+
dreamFromGrammar([list(f.entries)[0].program for f in fs],
|
240 |
+
f"{directory}/{MDL}")
|
241 |
+
MDL += 1
|
242 |
+
|
243 |
+
def visualizePrimitives(primitives, export='/tmp/logo_primitives.png'):
|
244 |
+
from itertools import product
|
245 |
+
from dreamcoder.program import Index,Abstraction,Application
|
246 |
+
from dreamcoder.utilities import montageMatrix,makeNiceArray
|
247 |
+
from dreamcoder.type import tint
|
248 |
+
import scipy.misc
|
249 |
+
from dreamcoder.domains.logo.makeLogoTasks import parseLogo
|
250 |
+
|
251 |
+
angles = [Program.parse(a)
|
252 |
+
for a in ["logo_ZA",
|
253 |
+
"logo_epsA",
|
254 |
+
"(logo_MULA logo_epsA 2)",
|
255 |
+
"(logo_DIVA logo_UA 4)",
|
256 |
+
"(logo_DIVA logo_UA 5)",
|
257 |
+
"(logo_DIVA logo_UA 7)",
|
258 |
+
"(logo_DIVA logo_UA 9)",
|
259 |
+
] ]
|
260 |
+
specialAngles = {"#(lambda (lambda (logo_forLoop logo_IFTY (lambda (lambda (logo_FWRT (logo_MULL logo_UL 3) (logo_MULA $2 4) $0))) $1)))":
|
261 |
+
[Program.parse("(logo_MULA logo_epsA 4)")]+[Program.parse("(logo_DIVA logo_UA %d)"%n) for n in [7,9] ]}
|
262 |
+
numbers = [Program.parse(n)
|
263 |
+
for n in ["1","2","5","7","logo_IFTY"] ]
|
264 |
+
specialNumbers = {"#(lambda (#(lambda (lambda (lambda (lambda (logo_forLoop $2 (lambda (lambda (logo_FWRT $5 (logo_DIVA logo_UA $3) $0))) $0))))) (logo_MULL logo_UL $0) 4 4))":
|
265 |
+
[Program.parse(str(n)) for n in [1,2,3] ]}
|
266 |
+
distances = [Program.parse(l)
|
267 |
+
for l in ["logo_ZL",
|
268 |
+
"logo_epsL",
|
269 |
+
"(logo_MULL logo_epsL 2)",
|
270 |
+
"(logo_DIVL logo_UL 2)",
|
271 |
+
"logo_UL"] ]
|
272 |
+
subprograms = [parseLogo(sp)
|
273 |
+
for sp in ["(move 1d 0a)",
|
274 |
+
"(loop i infinity (move (*l epsilonLength 4) (*a epsilonAngle 2)))",
|
275 |
+
"(loop i infinity (move (*l epsilonLength 5) (/a epsilonAngle 2)))",
|
276 |
+
"(loop i 4 (move 1d (/a 1a 4)))"]]
|
277 |
+
|
278 |
+
entireArguments = {"#(lambda (lambda (#(#(lambda (lambda (lambda (logo_forLoop $2 (lambda (lambda (logo_FWRT $2 $3 $0))))))) logo_IFTY) (logo_MULA (#(logo_DIVA logo_UA) $1) $0) (#(logo_MULL logo_UL) 3))))":
|
279 |
+
[[Program.parse(str(x)) for x in xs ]
|
280 |
+
for xs in [("3", "1", "$0"),
|
281 |
+
("4", "1", "$0"),
|
282 |
+
("5", "1", "$0"),
|
283 |
+
("5", "3", "$0"),
|
284 |
+
("7", "3", "$0")]]}
|
285 |
+
specialDistances = {"#(lambda (lambda (logo_forLoop 7 (lambda (lambda (#(lambda (lambda (lambda (#(lambda (lambda (lambda (logo_forLoop $2 (lambda (lambda (logo_FWRT $2 $3 $0))))))) 7 $1 $2 $0)))) $3 logo_epsA $0))) $0)))":
|
286 |
+
[Program.parse("(logo_MULL logo_epsL %d)"%n) for n in range(5)]}
|
287 |
+
|
288 |
+
matrix = []
|
289 |
+
for p in primitives:
|
290 |
+
if not p.isInvented: continue
|
291 |
+
t = p.tp
|
292 |
+
eprint(p,":",p.tp)
|
293 |
+
if t.returns() != turtle:
|
294 |
+
eprint("\t(does not return a turtle)")
|
295 |
+
continue
|
296 |
+
|
297 |
+
def argumentChoices(t):
|
298 |
+
if t == turtle:
|
299 |
+
return [Index(0)]
|
300 |
+
elif t == arrow(turtle,turtle):
|
301 |
+
return subprograms
|
302 |
+
elif t == tint:
|
303 |
+
return specialNumbers.get(str(p),numbers)
|
304 |
+
elif t == tangle:
|
305 |
+
return specialAngles.get(str(p),angles)
|
306 |
+
elif t == tlength:
|
307 |
+
return specialDistances.get(str(p),distances)
|
308 |
+
else: return []
|
309 |
+
|
310 |
+
ts = []
|
311 |
+
for arguments in entireArguments.get(str(p),product(*[argumentChoices(t) for t in t.functionArguments() ])):
|
312 |
+
eprint(arguments)
|
313 |
+
pp = p
|
314 |
+
for a in arguments: pp = Application(pp,a)
|
315 |
+
pp = Abstraction(pp)
|
316 |
+
i = np.reshape(np.array(drawLogo(pp, resolution=128)), (128,128))
|
317 |
+
if i is not None:
|
318 |
+
ts.append(i)
|
319 |
+
|
320 |
+
|
321 |
+
if ts == []: continue
|
322 |
+
|
323 |
+
matrix.append(ts)
|
324 |
+
if len(ts) < 6: ts = [ts]
|
325 |
+
else: ts = makeNiceArray(ts)
|
326 |
+
r = montageMatrix(ts)
|
327 |
+
fn = "/tmp/logo_primitive_%d.png"%len(matrix)
|
328 |
+
eprint("\tExported to",fn)
|
329 |
+
scipy.misc.imsave(fn, r)
|
330 |
+
|
331 |
+
matrix = montageMatrix(matrix)
|
332 |
+
scipy.misc.imsave(export, matrix)
|
333 |
+
|
334 |
+
|
335 |
+
def main(args):
|
336 |
+
"""
|
337 |
+
Takes the return value of the `commandlineArguments()` function as input and
|
338 |
+
trains/tests the model on LOGO tasks.
|
339 |
+
"""
|
340 |
+
|
341 |
+
# The below legacy global statement is required since prefix_dreams is used by LogoFeatureCNN.
|
342 |
+
# TODO(lcary): use argument passing instead of global variables.
|
343 |
+
global prefix_dreams
|
344 |
+
|
345 |
+
# The below global statement is required since primitives is modified within main().
|
346 |
+
# TODO(lcary): use a function call to retrieve and declare primitives instead.
|
347 |
+
global primitives
|
348 |
+
|
349 |
+
visualizeCheckpoint = args.pop("visualize")
|
350 |
+
if visualizeCheckpoint is not None:
|
351 |
+
with open(visualizeCheckpoint,'rb') as handle:
|
352 |
+
primitives = pickle.load(handle).grammars[-1].primitives
|
353 |
+
visualizePrimitives(primitives)
|
354 |
+
sys.exit(0)
|
355 |
+
|
356 |
+
dreamCheckpoint = args.pop("dreamCheckpoint")
|
357 |
+
dreamDirectory = args.pop("dreamDirectory")
|
358 |
+
|
359 |
+
proto = args.pop("proto")
|
360 |
+
|
361 |
+
if dreamCheckpoint is not None:
|
362 |
+
#outputDreams(dreamCheckpoint, dreamDirectory)
|
363 |
+
enumerateDreams(dreamCheckpoint, dreamDirectory)
|
364 |
+
sys.exit(0)
|
365 |
+
|
366 |
+
animateCheckpoint = args.pop("animate")
|
367 |
+
if animateCheckpoint is not None:
|
368 |
+
animateSolutions(loadPickle(animateCheckpoint).allFrontiers)
|
369 |
+
sys.exit(0)
|
370 |
+
|
371 |
+
target = args.pop("target")
|
372 |
+
red = args.pop("reduce")
|
373 |
+
save = args.pop("save")
|
374 |
+
prefix = args.pop("prefix")
|
375 |
+
prefix_dreams = prefix + "/dreams/" + ('_'.join(target)) + "/"
|
376 |
+
prefix_pickles = prefix + "/logo." + ('.'.join(target))
|
377 |
+
if not os.path.exists(prefix_dreams):
|
378 |
+
os.makedirs(prefix_dreams)
|
379 |
+
tasks = makeTasks(target, proto)
|
380 |
+
eprint("Generated", len(tasks), "tasks")
|
381 |
+
|
382 |
+
costMatters = args.pop("cost")
|
383 |
+
for t in tasks:
|
384 |
+
t.specialTask[1]["costMatters"] = costMatters
|
385 |
+
# disgusting hack - include whether cost matters in the dummy input
|
386 |
+
if costMatters: t.examples = [(([1]), t.examples[0][1])]
|
387 |
+
|
388 |
+
os.chdir("prototypical-networks")
|
389 |
+
subprocess.Popen(["python","./protonet_server.py"])
|
390 |
+
time.sleep(3)
|
391 |
+
os.chdir("..")
|
392 |
+
|
393 |
+
|
394 |
+
test, train = testTrainSplit(tasks, args.pop("split"))
|
395 |
+
eprint("Split tasks into %d/%d test/train" % (len(test), len(train)))
|
396 |
+
try:
|
397 |
+
if test: montageTasks(test,"test_")
|
398 |
+
montageTasks(train,"train_")
|
399 |
+
except:
|
400 |
+
eprint("WARNING: couldn't generate montage. Do you have an old version of scipy?")
|
401 |
+
|
402 |
+
if red is not []:
|
403 |
+
for reducing in red:
|
404 |
+
try:
|
405 |
+
with open(reducing, 'r') as f:
|
406 |
+
prods = json.load(f)
|
407 |
+
for e in prods:
|
408 |
+
e = Program.parse(e)
|
409 |
+
if e.isInvented:
|
410 |
+
primitives.append(e)
|
411 |
+
except EOFError:
|
412 |
+
eprint("Couldn't grab frontier from " + reducing)
|
413 |
+
except IOError:
|
414 |
+
eprint("Couldn't grab frontier from " + reducing)
|
415 |
+
except json.decoder.JSONDecodeError:
|
416 |
+
eprint("Couldn't grab frontier from " + reducing)
|
417 |
+
|
418 |
+
primitives = list(OrderedDict((x, True) for x in primitives).keys())
|
419 |
+
baseGrammar = Grammar.uniform(primitives, continuationType=turtle)
|
420 |
+
|
421 |
+
eprint(baseGrammar)
|
422 |
+
|
423 |
+
timestamp = datetime.datetime.now().isoformat()
|
424 |
+
outputDirectory = "experimentOutputs/logo/%s"%timestamp
|
425 |
+
os.system("mkdir -p %s"%outputDirectory)
|
426 |
+
|
427 |
+
|
428 |
+
generator = ecIterator(baseGrammar, train,
|
429 |
+
testingTasks=test,
|
430 |
+
outputPrefix="%s/logo"%outputDirectory,
|
431 |
+
evaluationTimeout=0.01,
|
432 |
+
**args)
|
433 |
+
|
434 |
+
r = None
|
435 |
+
for result in generator:
|
436 |
+
iteration = len(result.learningCurve)
|
437 |
+
dreamDirectory = "%s/dreams_%d"%(outputDirectory, iteration)
|
438 |
+
os.system("mkdir -p %s"%dreamDirectory)
|
439 |
+
eprint("Dreaming into directory",dreamDirectory)
|
440 |
+
dreamFromGrammar(result.grammars[-1],
|
441 |
+
dreamDirectory)
|
442 |
+
r = result
|
443 |
+
|
444 |
+
needsExport = [str(z)
|
445 |
+
for _, _, z
|
446 |
+
in r.grammars[-1].productions
|
447 |
+
if z.isInvented]
|
448 |
+
if save is not None:
|
449 |
+
with open(save, 'w') as f:
|
450 |
+
json.dump(needsExport, f)
|
dreamcoder/domains/logo/makeLogoTasks.py
ADDED
@@ -0,0 +1,777 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf8
|
2 |
+
|
3 |
+
import os
|
4 |
+
import random
|
5 |
+
import sys
|
6 |
+
|
7 |
+
from dreamcoder.domains.logo.logoPrimitives import primitives, turtle
|
8 |
+
from dreamcoder.task import Task
|
9 |
+
from dreamcoder.program import Abstraction, Application, Index, Program
|
10 |
+
from dreamcoder.type import arrow
|
11 |
+
from dreamcoder.utilities import eprint, jsonBinaryInvoke, random_seed, montage
|
12 |
+
from dreamcoder.grammar import Grammar
|
13 |
+
|
14 |
+
|
15 |
+
def drawLogo(*programs,
|
16 |
+
timeout=None,
|
17 |
+
resolution=None,
|
18 |
+
pretty=False, smoothPretty=False,
|
19 |
+
filenames=[],
|
20 |
+
animate=False,
|
21 |
+
cost=False):
|
22 |
+
message = {}
|
23 |
+
if pretty: message["pretty"] = pretty
|
24 |
+
if smoothPretty: message["smoothPretty"] = smoothPretty
|
25 |
+
if timeout: message["timeout"] = timeout
|
26 |
+
assert resolution is not None, "resolution not provided in drawLogo"
|
27 |
+
if isinstance(resolution, list):
|
28 |
+
assert len(resolution) == len(programs), "must provide a resolution for each program"
|
29 |
+
elif isinstance(resolution, int):
|
30 |
+
resolution = [resolution]*len(programs)
|
31 |
+
else: assert False
|
32 |
+
jobs = []
|
33 |
+
for p, size in zip(programs, resolution):
|
34 |
+
entry = {"program": str(p),
|
35 |
+
"size": size}
|
36 |
+
if animate: entry["animate"] = True
|
37 |
+
if len(filenames) > 0:
|
38 |
+
entry["export"] = filenames[0]
|
39 |
+
filenames = filenames[1:]
|
40 |
+
jobs.append(entry)
|
41 |
+
message["jobs"] = jobs
|
42 |
+
response = jsonBinaryInvoke("./logoDrawString", message)
|
43 |
+
if cost:
|
44 |
+
# include the cost and return tuples of (pixels, cost)
|
45 |
+
response = [programResponse if isinstance(programResponse,str) else (programResponse["pixels"], programResponse["cost"])
|
46 |
+
for programResponse in response ]
|
47 |
+
else:
|
48 |
+
response = [programResponse if isinstance(programResponse,str) else programResponse["pixels"]
|
49 |
+
for programResponse in response ]
|
50 |
+
if len(programs) == 1:
|
51 |
+
return response[0]
|
52 |
+
return response
|
53 |
+
|
54 |
+
def makeTasks(subfolders, proto):
|
55 |
+
return manualLogoTasks()
|
56 |
+
|
57 |
+
def parseLogo(s):
|
58 |
+
|
59 |
+
_ua = Program.parse("logo_UA")
|
60 |
+
_ul = Program.parse("logo_UL")
|
61 |
+
|
62 |
+
_za = Program.parse("logo_ZA")
|
63 |
+
_zl = Program.parse("logo_ZL")
|
64 |
+
|
65 |
+
_da = Program.parse("logo_DIVA")
|
66 |
+
_ma = Program.parse("logo_MULA")
|
67 |
+
_dl = Program.parse("logo_DIVL")
|
68 |
+
_ml = Program.parse("logo_MULL")
|
69 |
+
|
70 |
+
_aa = Program.parse("logo_ADDA")
|
71 |
+
_sa = Program.parse("logo_SUBA")
|
72 |
+
_al = None#Program.parse("logo_ADDL")
|
73 |
+
_sl = None#Program.parse("logo_SUBL")
|
74 |
+
|
75 |
+
_pu = None#Program.parse("logo_PU")
|
76 |
+
_pd = None#Program.parse("logo_PD")
|
77 |
+
_p = Program.parse("logo_PT")
|
78 |
+
_move = Program.parse("logo_FWRT")
|
79 |
+
_embed = Program.parse("logo_GETSET")
|
80 |
+
|
81 |
+
_addition = Program.parse("+")
|
82 |
+
_infinity = Program.parse("logo_IFTY")
|
83 |
+
_ea = Program.parse("logo_epsA")
|
84 |
+
_el = Program.parse("logo_epsL")
|
85 |
+
_loop = Program.parse("logo_forLoop")
|
86 |
+
|
87 |
+
from sexpdata import loads, Symbol
|
88 |
+
s = loads(s)
|
89 |
+
def command(k, environment, continuation):
|
90 |
+
assert isinstance(k,list)
|
91 |
+
if k[0] == Symbol("move"):
|
92 |
+
return Application(Application(Application(_move,
|
93 |
+
expression(k[1],environment)),
|
94 |
+
expression(k[2],environment)),
|
95 |
+
continuation)
|
96 |
+
if k[0] == Symbol("for") or k[0] == Symbol("loop"):
|
97 |
+
v = k[1]
|
98 |
+
b = expression(k[2], environment)
|
99 |
+
newEnvironment = [None, v] + environment
|
100 |
+
body = block(k[3:], newEnvironment, Index(0))
|
101 |
+
return Application(Application(Application(_loop,b),
|
102 |
+
Abstraction(Abstraction(body))),
|
103 |
+
continuation)
|
104 |
+
if k[0] == Symbol("embed"):
|
105 |
+
body = block(k[1:], [None] + environment, Index(0))
|
106 |
+
return Application(Application(_embed,Abstraction(body)),continuation)
|
107 |
+
if k[0] == Symbol("p"):
|
108 |
+
body = block(k[1:], [None] + environment, Index(0))
|
109 |
+
return Application(Application(_p,Abstraction(body)),continuation)
|
110 |
+
|
111 |
+
assert False
|
112 |
+
def expression(e, environment):
|
113 |
+
for n, v in enumerate(environment):
|
114 |
+
if e == v: return Index(n)
|
115 |
+
|
116 |
+
if isinstance(e,int): return Program.parse(str(e))
|
117 |
+
|
118 |
+
mapping = {"1a": _ua,
|
119 |
+
"1d": _ul, "1l": _ul,
|
120 |
+
"0a": _za,
|
121 |
+
"0d": _zl, "0l": _zl,
|
122 |
+
"/a": _da,
|
123 |
+
"/l": _dl, "/d": _dl,
|
124 |
+
"*a": _ma,
|
125 |
+
"*l": _ml, "*d": _ml,
|
126 |
+
"+a": _aa,
|
127 |
+
"+d": _al, "+l": _al,
|
128 |
+
"-a": _sa,
|
129 |
+
"-d": _sl, "-l": _sl,
|
130 |
+
"+": _addition,
|
131 |
+
"infinity": _infinity,
|
132 |
+
"epsilonAngle": _ea,
|
133 |
+
"epsilonDistance": _el,
|
134 |
+
"epsilonLength": _el}
|
135 |
+
if e == float('inf'): return _infinity
|
136 |
+
for name, value in mapping.items():
|
137 |
+
if e == Symbol(name): return value
|
138 |
+
|
139 |
+
assert isinstance(e,list), "not a list %s"%e
|
140 |
+
for name, value in mapping.items():
|
141 |
+
if e[0] == Symbol(name):
|
142 |
+
f = value
|
143 |
+
for argument in e[1:]:
|
144 |
+
f = Application(f, expression(argument, environment))
|
145 |
+
return f
|
146 |
+
assert False
|
147 |
+
|
148 |
+
def block(b, environment, continuation):
|
149 |
+
if len(b) == 0: return continuation
|
150 |
+
return command(b[0], environment, block(b[1:], environment, continuation))
|
151 |
+
|
152 |
+
try: return Abstraction(command(s, [], Index(0)))
|
153 |
+
except: return Abstraction(block(s, [], Index(0)))
|
154 |
+
|
155 |
+
|
156 |
+
def manualLogoTask(name, expression, proto=False, needToTrain=False,
|
157 |
+
supervise=False, lambdaCalculus=False):
|
158 |
+
p = Program.parse(expression) if lambdaCalculus else parseLogo(expression)
|
159 |
+
from dreamcoder.domains.logo.logoPrimitives import primitives
|
160 |
+
from dreamcoder.grammar import Grammar
|
161 |
+
g = Grammar.uniform(primitives, continuationType=turtle)
|
162 |
+
gp = Grammar.uniform(primitives)
|
163 |
+
try:
|
164 |
+
l = g.logLikelihood(arrow(turtle,turtle),p)
|
165 |
+
lp = gp.logLikelihood(arrow(turtle,turtle),p)
|
166 |
+
assert l >= lp
|
167 |
+
eprint(name,-l,"nats")
|
168 |
+
|
169 |
+
except: eprint("WARNING: could not calculate likelihood of manual logo",p)
|
170 |
+
|
171 |
+
attempts = 0
|
172 |
+
while True:
|
173 |
+
[output, highresolution] = drawLogo(p, p, resolution=[28,128], cost=True)
|
174 |
+
if output == "timeout" or highresolution == "timeout":
|
175 |
+
attempts += 1
|
176 |
+
else:
|
177 |
+
break
|
178 |
+
if attempts > 0:
|
179 |
+
eprint(f"WARNING: Took {attempts} attempts to render task {name} within timeout")
|
180 |
+
|
181 |
+
cost = output[1]
|
182 |
+
output = output[0]
|
183 |
+
assert highresolution[1] == cost
|
184 |
+
highresolution = highresolution[0]
|
185 |
+
|
186 |
+
shape = list(map(int, output))
|
187 |
+
highresolution = list(map(float, highresolution))
|
188 |
+
t = Task(name, arrow(turtle,turtle),
|
189 |
+
[(([0]), shape)])
|
190 |
+
t.mustTrain = needToTrain
|
191 |
+
t.proto = proto
|
192 |
+
t.specialTask = ("LOGO", {"proto": proto})
|
193 |
+
t.specialTask[1]["cost"] = cost*1.05
|
194 |
+
|
195 |
+
t.highresolution = highresolution
|
196 |
+
|
197 |
+
if supervise:
|
198 |
+
t.supervisedSolution = p
|
199 |
+
|
200 |
+
return t
|
201 |
+
|
202 |
+
def dSLDemo():
|
203 |
+
n = 0
|
204 |
+
demos = []
|
205 |
+
def T(source):
|
206 |
+
demos.append(manualLogoTask(str(len(demos)), source,
|
207 |
+
lambdaCalculus="lambda" in source))
|
208 |
+
# this looks like polygons - verify and include
|
209 |
+
T("(#(lambda (lambda (#(lambda (lambda (#(lambda (lambda (lambda (logo_forLoop $0 (lambda (lambda (logo_FWRT $4 $3 $0))))))) $1 $0 logo_IFTY))) $1 (logo_DIVA logo_UA $0)))) (logo_MULL logo_UL 4) 3)")
|
210 |
+
T("(#(lambda (lambda (#(lambda (lambda (#(lambda (lambda (lambda (logo_forLoop $0 (lambda (lambda (logo_FWRT $4 $3 $0))))))) $1 $0 logo_IFTY))) $1 (logo_DIVA logo_UA $0)))) (logo_MULL logo_UL 6) 4)")
|
211 |
+
T("(#(lambda (lambda (#(lambda (lambda (#(lambda (lambda (lambda (logo_forLoop $0 (lambda (lambda (logo_FWRT $4 $3 $0))))))) $1 $0 logo_IFTY))) $1 (logo_DIVA logo_UA $0)))) (logo_MULL logo_UL 5) 5)")
|
212 |
+
T("(#(lambda (lambda (#(lambda (lambda (#(lambda (lambda (lambda (logo_forLoop $0 (lambda (lambda (logo_FWRT $4 $3 $0))))))) $1 $0 logo_IFTY))) $1 (logo_DIVA logo_UA $0)))) (logo_MULL logo_UL 3) 6)")
|
213 |
+
T("(#(lambda (lambda (#(lambda (lambda (#(lambda (lambda (lambda (logo_forLoop $0 (lambda (lambda (logo_FWRT $4 $3 $0))))))) $1 $0 logo_IFTY))) $1 (logo_DIVA logo_UA $0)))) (logo_MULL logo_UL 2) 7)")
|
214 |
+
|
215 |
+
# Spirals!
|
216 |
+
for spiralSize in [1,2,3,4,5]:
|
217 |
+
T(f"((lambda (logo_forLoop logo_IFTY (lambda (lambda (logo_FWRT (logo_MULL logo_epsL $1) (logo_MULA logo_epsA $2) $0))))) {spiralSize})")
|
218 |
+
for spiralSize in [5,6,7,8,9]:
|
219 |
+
#T(f"(lambda (#(lambda (logo_forLoop $0 (lambda (lambda (#(lambda (logo_FWRT (logo_MULL logo_UL $0) (logo_DIVA logo_UA 4))) $1 $0))))) {spiralSize} $0))")
|
220 |
+
T("(loop i " + str(spiralSize) + " (move (*d 1l i) (/a 1a 4)))")# (#(lambda (logo_forLoop $0 (lambda (lambda (#(lambda (logo_FWRT (logo_MULL logo_UL $0) (logo_DIVA logo_UA 4))) $1 $0))))) {spiralSize} $0))")
|
221 |
+
|
222 |
+
# CIRCLES
|
223 |
+
#(lambda (#(lambda (logo_forLoop 6 (lambda (lambda (#(lambda (lambda (logo_forLoop logo_IFTY (lambda (lambda (logo_FWRT $2 $3 $0)))))) logo_epsA (logo_MULL logo_epsL $2) $0))))) 6 $0))
|
224 |
+
for circleSize in [1,3,5,7,9]:
|
225 |
+
T(f"(lambda (#(lambda (logo_forLoop 6 (lambda (lambda (#(lambda (lambda (logo_forLoop logo_IFTY (lambda (lambda (logo_FWRT $2 $3 $0)))))) logo_epsA (logo_MULL logo_epsL $2) $0))))) {circleSize} $0))")
|
226 |
+
|
227 |
+
T("(loop i 3 (move (*d 1l 3) (/a 1a 4)))")
|
228 |
+
T("(loop i 5 (move (*d 1l 5) (/a 1a 5)))")
|
229 |
+
T("(loop i infinity (move (*d epsilonDistance 5) (/a epsilonAngle 3)))")
|
230 |
+
T("(loop i infinity (move (*d epsilonDistance 9) (/a epsilonAngle 2)))")
|
231 |
+
T("(loop i infinity (move (*d epsilonLength i) (*a epsilonAngle 3)))")
|
232 |
+
T("(loop i 9 (move (*d 1l i) (/a 1a 4)))")
|
233 |
+
T("(move 1d 0a)")
|
234 |
+
T("(loop i infinity (move (*d epsilonLength 6) epsilonAngle))")
|
235 |
+
T("(loop i infinity (move (*d epsilonLength 8) epsilonAngle))")
|
236 |
+
T("(loop k 2 (loop i infinity (move (*d epsilonLength 4) epsilonAngle)))")
|
237 |
+
T("(loop k 2 (loop i infinity (move (*d epsilonLength 8) epsilonAngle)))")
|
238 |
+
T("(loop s 4 (move (*d 1d 3) (/a 1a 4)))")
|
239 |
+
T("(loop s 4 (move (*d 1d 6) (/a 1a 4)))")
|
240 |
+
T("""
|
241 |
+
(loop j 5
|
242 |
+
(move 0d (/a 1a 5))
|
243 |
+
(embed (loop i infinity
|
244 |
+
(move (*d epsilonLength 6) epsilonAngle))
|
245 |
+
(loop i infinity
|
246 |
+
(move (*d epsilonLength 6) epsilonAngle))))""")
|
247 |
+
T("""
|
248 |
+
(loop j 5
|
249 |
+
(embed (loop s 4 (move (*d 1d 3) (/a 1a 4))))
|
250 |
+
(move 0d (/a 1a 5)))""")
|
251 |
+
return demos
|
252 |
+
|
253 |
+
def rotationalSymmetryDemo():
|
254 |
+
demos = []
|
255 |
+
def T(source):
|
256 |
+
demos.append(manualLogoTask(str(len(demos)), source))
|
257 |
+
|
258 |
+
body = {"dashed": "(p (move 1d 0a)) (move 1d 0a) (p (move 1d 0a)) (move 1d 0a)",
|
259 |
+
"lonely circle": "(p (move (*d 1d 2) 0a)) (loop k 2 (loop i infinity (move (*d epsilonLength 2) epsilonAngle)))",
|
260 |
+
"square dashed": "(p (move 1d 0a)) (loop s 4 (move 1d (/a 1a 4)))",
|
261 |
+
"square": "(loop s 4 (move (*d 1d 2) (/a 1a 4)))",
|
262 |
+
"semicircle": "(loop i infinity (move (*d epsilonLength 4) epsilonAngle))"}
|
263 |
+
for name in body:
|
264 |
+
for n in [3,4,5,6,7]:
|
265 |
+
T("""
|
266 |
+
(loop j %d
|
267 |
+
(embed %s)
|
268 |
+
(move 0d (/a 1a %d)))"""%(n,body[name],n))
|
269 |
+
return demos
|
270 |
+
|
271 |
+
|
272 |
+
def manualLogoTasks():
|
273 |
+
tasks = []
|
274 |
+
def T(name, source, needToTrain=False, supervise=False):
|
275 |
+
tasks.append(manualLogoTask(name, source, supervise=supervise,
|
276 |
+
needToTrain=needToTrain))
|
277 |
+
if False:
|
278 |
+
for d,a,s in [('1l','0a','(loop i infinity (move epsilonLength epsilonAngle))'),
|
279 |
+
('epsilonLength','0a','(loop i infinity (move epsilonLength epsilonAngle))'),
|
280 |
+
('(*d 1l 3)','0a','(move 1l 0a)'),
|
281 |
+
('epsilonLength','0a','(move (*d 1l 2) 0a)'),
|
282 |
+
('(*d epsilonLength 9)','0a','(move epsilonLength 0a)'),
|
283 |
+
('(/d 1l 2)','0a','(move 1l 0a)')]:
|
284 |
+
# 'epsilonLength']:
|
285 |
+
# for a in ['epsilonAngle','0a']:
|
286 |
+
# for s in ['(move 1l 0a)',
|
287 |
+
# '(move epsilonLength 0a)',
|
288 |
+
# '(loop i infinity (move epsilonLength epsilonAngle))']:
|
289 |
+
# if d == 'epsilonLength' and s == '(move epsilonLength 0a)': continue
|
290 |
+
T("pu: %s/%s/%s"%(d,a,s),
|
291 |
+
"""
|
292 |
+
(pu (move %s %s) pd %s)
|
293 |
+
"""%(d,a,s))
|
294 |
+
return tasks
|
295 |
+
|
296 |
+
def slant(n):
|
297 |
+
return f"(move 0d (/a 1a {n}))"
|
298 |
+
|
299 |
+
for n,l,s in [(3,"1l",8),
|
300 |
+
(4,"(*d 1d 3)",None),
|
301 |
+
(5,"1l",None),
|
302 |
+
(6,"(*d 1d 2)",5),
|
303 |
+
(7,"1l",None),
|
304 |
+
(8,"(/d 1d 2)",None)]:
|
305 |
+
T(f"{n}-gon {l}{'' if s is None else ' slanted '+str(s)}",
|
306 |
+
f"""
|
307 |
+
({'' if s is None else slant(s)}
|
308 |
+
(loop i {n}
|
309 |
+
(move {l} (/a 1a {n}))))
|
310 |
+
""",
|
311 |
+
needToTrain=True)
|
312 |
+
for n,l,s in [(3,"(*d 1l 2)",None),
|
313 |
+
(4,"(*d 1d 4)",None),
|
314 |
+
(5,"(*d 1d 2)",None),
|
315 |
+
(6,"1l",None),
|
316 |
+
(7,"(*d 1d 3)",None),
|
317 |
+
(8,"1l",3)]:
|
318 |
+
T(f"{n}-gon {l}{'' if s is None else ' slanted '+str(s)}",
|
319 |
+
f"""
|
320 |
+
({'' if s is None else slant(s)}
|
321 |
+
(loop i {n}
|
322 |
+
(move {l} (/a 1a {n}))))
|
323 |
+
""",
|
324 |
+
needToTrain=False)
|
325 |
+
|
326 |
+
|
327 |
+
|
328 |
+
T("upwards", "((move 0d (/a 1a 4)) (move 1d 0a))",
|
329 |
+
needToTrain=True)
|
330 |
+
T("right angle", "((move (*d 1d 2) (/a 1a 4)) (move 1d 0a))",
|
331 |
+
needToTrain=True)
|
332 |
+
T("right angle epsilon", "((move epsilonLength (/a 1a 4)) (move epsilonLength 0a))",
|
333 |
+
needToTrain=True)
|
334 |
+
|
335 |
+
T("line segment", "(move 1d 0a)",
|
336 |
+
needToTrain=True)
|
337 |
+
|
338 |
+
T("square slanted by 2pi/3",
|
339 |
+
"""((move 0d (/a 1a 3))
|
340 |
+
(loop k 4 (move 1d (/a 1a 4))))""",
|
341 |
+
needToTrain=True)
|
342 |
+
T("semicircle slanted by 2pi/5",
|
343 |
+
"""((move 0d (/a 1a 5))
|
344 |
+
(loop i infinity
|
345 |
+
(move (*d epsilonLength 4) epsilonAngle)))""",
|
346 |
+
needToTrain=True)
|
347 |
+
T("Greek spiral slanted by 2pi/6",
|
348 |
+
"""((move 0d (/a 1a 6))
|
349 |
+
(loop i 7 (move (*l 1l i) (/a 1a 4))))""",
|
350 |
+
needToTrain=True)
|
351 |
+
T("Hook slanted by 2pi/7",
|
352 |
+
"""((move 0d (/a 1a 7))
|
353 |
+
(move 1d 0a)
|
354 |
+
(loop i infinity
|
355 |
+
(move (*d epsilonLength 4) epsilonAngle)))""",
|
356 |
+
needToTrain=True)
|
357 |
+
T("""slanted line""",
|
358 |
+
"""((move 0d (/a 1a 8))
|
359 |
+
(move (*d 1l 3) 0a))""",
|
360 |
+
needToTrain=True)
|
361 |
+
|
362 |
+
|
363 |
+
for i in [6,7,8,9]:
|
364 |
+
T("Greek spiral %d"%i,
|
365 |
+
"""
|
366 |
+
(loop i %d
|
367 |
+
(move (*l 1l i) (/a 1a 4)))
|
368 |
+
"""%i,
|
369 |
+
needToTrain=i in [7,8])
|
370 |
+
for i in [2,3,4,5]:
|
371 |
+
T("smooth spiral %d"%i,
|
372 |
+
"""
|
373 |
+
(loop i infinity
|
374 |
+
(move (*d epsilonLength i) (*a epsilonAngle %d)))
|
375 |
+
"""%i,
|
376 |
+
needToTrain=i in [3,5])
|
377 |
+
|
378 |
+
T("smooth spiral 4 slanted by 2pi/2",
|
379 |
+
"""
|
380 |
+
((move 0d (/a 1a 2))
|
381 |
+
(loop i infinity
|
382 |
+
(move (*d epsilonLength i) (*a epsilonAngle 4))))
|
383 |
+
""",
|
384 |
+
needToTrain=True)
|
385 |
+
|
386 |
+
for i in [3,5,7,9]:
|
387 |
+
T("star %d"%i,
|
388 |
+
"""
|
389 |
+
(loop i %d (move (*d 1d 4) (-a (/a 1a 2) (/a (/a 1a 2) %s))))
|
390 |
+
"""%(i,i),
|
391 |
+
needToTrain=i in [5,9])
|
392 |
+
|
393 |
+
T("leaf iteration 1.1",
|
394 |
+
"""
|
395 |
+
(loop i infinity (move epsilonDistance (/a epsilonAngle 2)))
|
396 |
+
""",
|
397 |
+
needToTrain=True)
|
398 |
+
T("leaf iteration 1.2",
|
399 |
+
"""
|
400 |
+
((move 0d (/a 1a 2))
|
401 |
+
(loop i infinity (move epsilonDistance (/a epsilonAngle 2))))
|
402 |
+
""",
|
403 |
+
needToTrain=True)
|
404 |
+
T("leaf iteration 2.1",
|
405 |
+
"""
|
406 |
+
(loop n 2
|
407 |
+
(loop i infinity (move epsilonDistance (/a epsilonAngle 2)))
|
408 |
+
(move 0d (/a 1a 4)))
|
409 |
+
""",
|
410 |
+
needToTrain=True)
|
411 |
+
T("leaf iteration 2.2",
|
412 |
+
"""
|
413 |
+
((move 0d (/a 1a 2))
|
414 |
+
(loop n 2
|
415 |
+
(loop i infinity (move epsilonDistance (/a epsilonAngle 2)))
|
416 |
+
(move 0d (/a 1a 4))))
|
417 |
+
""",
|
418 |
+
needToTrain=True)
|
419 |
+
for n in range(3,8):
|
420 |
+
T("flower %d"%n,
|
421 |
+
"""
|
422 |
+
(loop j %d
|
423 |
+
(loop n 2
|
424 |
+
(loop i infinity (move epsilonDistance (/a epsilonAngle 2)))
|
425 |
+
(move 0d (/a 1a 4)))
|
426 |
+
(move 0d (/a 1a %d)))
|
427 |
+
"""%(n,n),
|
428 |
+
needToTrain=n in range(3,5))
|
429 |
+
|
430 |
+
for n in [5,6]:
|
431 |
+
T("staircase %d"%n,
|
432 |
+
"""
|
433 |
+
(loop i %d
|
434 |
+
(move 1d (/a 1a 4))
|
435 |
+
(move 1d (/a 1a 4))
|
436 |
+
(move 0d (/a 1a 2)))
|
437 |
+
"""%n,
|
438 |
+
needToTrain=n in [5])
|
439 |
+
|
440 |
+
for n in range(1,6):
|
441 |
+
T("blocks zigzag %d"%n,
|
442 |
+
"""
|
443 |
+
(loop i %d
|
444 |
+
(move 1d (/a 1a 4)) (move 1d (/a 1a 4))
|
445 |
+
(move 1d (+a (/a 1a 2) (/a 1a 4))) (move 1d (+a (/a 1a 2) (/a 1a 4))))
|
446 |
+
"""%n,
|
447 |
+
needToTrain=n in [1,2,3])
|
448 |
+
for n in [3,4]:#range(1,5):
|
449 |
+
T("diagonal zigzag %d"%n,
|
450 |
+
"""
|
451 |
+
((move 0d (/a 1a 8))
|
452 |
+
(loop i %d
|
453 |
+
(move 1d (/a 1a 4))
|
454 |
+
(move 1d (+a (/a 1a 2) (/a 1a 4)))))
|
455 |
+
"""%n,
|
456 |
+
needToTrain=n == 4)
|
457 |
+
|
458 |
+
|
459 |
+
|
460 |
+
for n in [1,2,3,4,5,6]:
|
461 |
+
T("right semicircle of size %d"%n,
|
462 |
+
"""
|
463 |
+
(loop i infinity
|
464 |
+
(move (*d epsilonLength %d) (-a 0a epsilonAngle)))
|
465 |
+
"""%n,
|
466 |
+
needToTrain=n%2 == 0)
|
467 |
+
T("left semicircle of size %d"%n,
|
468 |
+
f"""
|
469 |
+
({'' if n != 1 else slant(8)}
|
470 |
+
(loop i infinity
|
471 |
+
(move (*d epsilonLength {n}) epsilonAngle)))
|
472 |
+
""",
|
473 |
+
needToTrain=n%2 == 1)
|
474 |
+
T("circle of size %d"%n,
|
475 |
+
"""
|
476 |
+
((loop i infinity
|
477 |
+
(move (*d epsilonLength %d) epsilonAngle))
|
478 |
+
(loop i infinity
|
479 |
+
(move (*d epsilonLength %d) epsilonAngle)))
|
480 |
+
"""%(n,n),
|
481 |
+
needToTrain=n in [1,4,3,5,6])
|
482 |
+
|
483 |
+
for n in [5,6]:
|
484 |
+
T("%d enclosed circles"%n,
|
485 |
+
"""
|
486 |
+
(loop j %d
|
487 |
+
(loop i infinity
|
488 |
+
(move (*d epsilonLength j) epsilonAngle))
|
489 |
+
(loop i infinity
|
490 |
+
(move (*d epsilonLength j) epsilonAngle)))"""%n,
|
491 |
+
needToTrain=n == 5)
|
492 |
+
|
493 |
+
for n,l in [(4,2),
|
494 |
+
(5,3),
|
495 |
+
(6,4),
|
496 |
+
(3,1)]:
|
497 |
+
T("%d-circle flower l=%d"%(n,l),
|
498 |
+
"""
|
499 |
+
(loop j %d
|
500 |
+
(move 0d (/a 1a %d))
|
501 |
+
(embed (loop i infinity
|
502 |
+
(move (*d epsilonLength %d) epsilonAngle))
|
503 |
+
(loop i infinity
|
504 |
+
(move (*d epsilonLength %d) epsilonAngle))))"""%(n,n,l,l),
|
505 |
+
needToTrain=(n,l) in [(6,4),(3,1)])
|
506 |
+
|
507 |
+
for n,l in [(3,1),(2,2),(1,3),
|
508 |
+
(2,1),(1,2),(1,1)]:
|
509 |
+
T("%d-semicircle sequence L=%d"%(n,l),
|
510 |
+
"""
|
511 |
+
(loop j %d
|
512 |
+
(loop i infinity
|
513 |
+
(move (*d epsilonLength %d) epsilonAngle))
|
514 |
+
(loop i infinity
|
515 |
+
(move (*d epsilonLength %d) (-a 0a epsilonAngle))))
|
516 |
+
"""%(n,l,l),
|
517 |
+
needToTrain=(n,l) in [(3,1),(2,2),(1,3)])
|
518 |
+
|
519 |
+
for n,l in [(2,"1d"),
|
520 |
+
(3,"1d")]:
|
521 |
+
T("row of %d circles"%n,
|
522 |
+
"""
|
523 |
+
(loop j %d
|
524 |
+
(embed (loop k 2 (loop i infinity (move epsilonLength epsilonAngle))))
|
525 |
+
(p (move %s 0a)))"""%(n,l),
|
526 |
+
needToTrain=n == 2)
|
527 |
+
for n,l in [(2,"1d"),
|
528 |
+
(3,"1d")]:
|
529 |
+
T("row of %d lines"%n,
|
530 |
+
"""
|
531 |
+
(loop j %d
|
532 |
+
(move 1d 0a)
|
533 |
+
(p (move %s 0a)))"""%(n,l),
|
534 |
+
needToTrain=n == 2)
|
535 |
+
T("line next to semicircle",
|
536 |
+
"""
|
537 |
+
((move 1d 0a) (p (move 1d 0a)) (loop i infinity (move epsilonLength epsilonAngle)))
|
538 |
+
""",
|
539 |
+
needToTrain=True)
|
540 |
+
for n,l in [(3,"(/d 1d 2)"),
|
541 |
+
(4,"(/d 1d 3)")]:
|
542 |
+
T("%d dashed lines of size %s"%(n,l),
|
543 |
+
"""(loop i %d (p (move 1d 0a)) (move %s 0a))"""%(n,l),
|
544 |
+
needToTrain=n == 3)
|
545 |
+
T("broken circle",
|
546 |
+
"""
|
547 |
+
((loop i infinity (move epsilonLength epsilonAngle)) (p (move 1d 0a)) (loop i infinity (move epsilonLength epsilonAngle)))
|
548 |
+
""",
|
549 |
+
needToTrain=True)
|
550 |
+
T("circle next to semicircle",
|
551 |
+
"""
|
552 |
+
((loop i infinity (move epsilonLength epsilonAngle))
|
553 |
+
(loop i infinity (move epsilonLength epsilonAngle))
|
554 |
+
(p (move 1d 0a))
|
555 |
+
(loop i infinity (move epsilonLength epsilonAngle)))
|
556 |
+
""",
|
557 |
+
needToTrain=True)
|
558 |
+
T("semicircle next to square",
|
559 |
+
"""
|
560 |
+
((loop i infinity (move epsilonLength epsilonAngle))
|
561 |
+
(p (move 1d 0a))
|
562 |
+
(loop i infinity (move 1d (/a 1a 4))))
|
563 |
+
""",
|
564 |
+
needToTrain=False)
|
565 |
+
T("circle next to square",
|
566 |
+
"""
|
567 |
+
((loop i infinity (move epsilonLength epsilonAngle))
|
568 |
+
(loop i infinity (move epsilonLength epsilonAngle))
|
569 |
+
(p (move 1d 0a))
|
570 |
+
(loop i infinity (move 1d (/a 1a 4))))
|
571 |
+
""",
|
572 |
+
needToTrain=False)
|
573 |
+
T("circle next to line",
|
574 |
+
"""
|
575 |
+
((loop i infinity (move epsilonLength epsilonAngle))
|
576 |
+
(loop i infinity (move epsilonLength epsilonAngle))
|
577 |
+
(p (move 1d 0a))
|
578 |
+
(move 1d 0a))
|
579 |
+
""",
|
580 |
+
needToTrain=True)
|
581 |
+
T("line next to circle",
|
582 |
+
"""
|
583 |
+
((move 1d 0a)
|
584 |
+
(p (move 1d 0a))
|
585 |
+
(loop i infinity (move epsilonLength epsilonAngle))
|
586 |
+
(loop i infinity (move epsilonLength epsilonAngle))
|
587 |
+
(move 1d 0a))
|
588 |
+
""",
|
589 |
+
needToTrain=True)
|
590 |
+
for n,l in [(4,"1d"),
|
591 |
+
(5,"1d")]:
|
592 |
+
T("row of %d dashes"%n,
|
593 |
+
"""
|
594 |
+
(loop j %d
|
595 |
+
(embed (move 0d (/a 1a 4)) (move 1d 0a))
|
596 |
+
(p (move %s 0a)))"""%(n,l),
|
597 |
+
needToTrain=n == 4)
|
598 |
+
for n,l in [(5,"1d"),(6,"1d")]:
|
599 |
+
T("row of %d semicircles"%n,
|
600 |
+
"""
|
601 |
+
(loop j %d
|
602 |
+
(embed (loop i infinity (move epsilonLength epsilonAngle)))
|
603 |
+
(p (move %s 0a)))"""%(n,l),
|
604 |
+
needToTrain=n == 5)
|
605 |
+
|
606 |
+
with random_seed(42): # carefully selected for maximum entropy
|
607 |
+
for n in [3,4,5,6,7]:
|
608 |
+
body = {"empty": "(move 1d 0a)",
|
609 |
+
"spiral": "(loop i infinity (move (*d epsilonLength i) (*a epsilonAngle 2)))",
|
610 |
+
"dashed": "(p (move 1d 0a)) (move 1d 0a)",
|
611 |
+
"circle": "(move 1d 0a) (loop k 2 (loop i infinity (move epsilonLength epsilonAngle)))",
|
612 |
+
"lonely circle": "(p (move 1d 0a)) (loop k 2 (loop i infinity (move epsilonLength epsilonAngle)))",
|
613 |
+
"square dashed": "(p (move 1d 0a)) (loop s 4 (move 1d (/a 1a 4)))",
|
614 |
+
"square": "(move 1d 0a) (loop s 4 (move 1d (/a 1a 4)))",
|
615 |
+
"close large semicircle": "(loop i infinity (move (*d epsilonLength 2) epsilonAngle))",
|
616 |
+
"close semicircle": "(loop i infinity (move epsilonLength epsilonAngle))",
|
617 |
+
"semicircle": "(move 1d 0a) (loop i infinity (move epsilonLength epsilonAngle))",
|
618 |
+
"double dashed": "(p (move 1d 0a)) (move 1d 0a) (p (move 1d 0a)) (move 1d 0a)",
|
619 |
+
"Greek": "(loop i 3 (move (*l 1l i) (/a 1a 4)))"}
|
620 |
+
for name in body:
|
621 |
+
if name == "spiral" and n not in [3,5]: continue
|
622 |
+
if name == "square" and n not in [5,3,6,7]: continue
|
623 |
+
if name == "semicircle" and n not in [5,3,4,6]: continue
|
624 |
+
if name == "Greek" and n not in [3,5]: continue
|
625 |
+
if name == "double dashed" and n not in [6,4,3]: continue
|
626 |
+
|
627 |
+
mustTrain = False
|
628 |
+
|
629 |
+
mustTrain = mustTrain or (n == 3 and name == "Greek")
|
630 |
+
mustTrain = mustTrain or (n == 7 and name == "empty")
|
631 |
+
mustTrain = mustTrain or (n == 5 and name == "dashed")
|
632 |
+
mustTrain = mustTrain or (n == 7 and name == "circle")
|
633 |
+
mustTrain = mustTrain or (n == 6 and name == "circle")
|
634 |
+
mustTrain = mustTrain or (n == 6 and name == "lonely circle")
|
635 |
+
mustTrain = mustTrain or (n == 5 and name == "square")
|
636 |
+
mustTrain = mustTrain or (n == 7 and name == "square")
|
637 |
+
mustTrain = mustTrain or (n == 5 and name == "semicircle")
|
638 |
+
mustTrain = mustTrain or (n == 3 and name == "square dashed")
|
639 |
+
mustTrain = mustTrain or (n == 6 and name == "close semicircle")
|
640 |
+
mustTrain = mustTrain or (n == 5 and name == "close large semicircle")
|
641 |
+
mustTrain = mustTrain or (n == 3 and name == "spiral")
|
642 |
+
mustTrain = mustTrain or (n == 6 and name == "double dashed")
|
643 |
+
mustTrain = mustTrain or (n == 3 and name == "double dashed")
|
644 |
+
#mustTrain = mustTrain or (n == 6 and name == "empty")
|
645 |
+
|
646 |
+
#mustTrain = mustTrain or (random.random() < 0.07) # calibrated to give 70 training tasks
|
647 |
+
|
648 |
+
|
649 |
+
# # cap number of super easy snowflakes
|
650 |
+
# if name == "empty" and n not in [7]: mustTrain = False
|
651 |
+
# if name == "dashed" and n not in [4]: mustTrain = False
|
652 |
+
|
653 |
+
|
654 |
+
T("%d-%s snowflake"%(n,name),
|
655 |
+
"""
|
656 |
+
(loop j %d
|
657 |
+
(embed %s)
|
658 |
+
(move 0d (/a 1a %d)))"""%(n,body[name],n),
|
659 |
+
needToTrain=mustTrain)
|
660 |
+
|
661 |
+
for n in [3,4]:#2,3,4]:
|
662 |
+
T("%d-row of squares"%n,
|
663 |
+
"""
|
664 |
+
(loop i %d
|
665 |
+
(embed (loop k 4 (move 1d (/a 1a 4))))
|
666 |
+
(move 1d 0a))
|
667 |
+
"""%n,
|
668 |
+
needToTrain=n == 4)
|
669 |
+
T("2x2 grid",
|
670 |
+
"""
|
671 |
+
(for x 2 (embed (for y 2
|
672 |
+
(embed (loop k 4 (move 1d (/a 1a 4))))
|
673 |
+
(move 1d 0a)))
|
674 |
+
(move 0d (/a 1a 4)) (move 1d (-a 0a (/a 1a 4))))
|
675 |
+
""")
|
676 |
+
T("slanted squares",
|
677 |
+
"""
|
678 |
+
((embed (loop k 4 (move 1d (/a 1a 4))))
|
679 |
+
(move 0d (/a 1a 8))
|
680 |
+
(loop k 4 (move 1d (/a 1a 4))))
|
681 |
+
""")
|
682 |
+
for l in range(1,6):
|
683 |
+
T("square of size %d"%l,
|
684 |
+
"""
|
685 |
+
(for i 4
|
686 |
+
(move (*d 1d %d) (/a 1a 4)))
|
687 |
+
"""%l,
|
688 |
+
needToTrain=l in range(4))
|
689 |
+
for n in [5,7]:
|
690 |
+
T("%d-concentric squares"%n,
|
691 |
+
"""
|
692 |
+
(for i %d
|
693 |
+
(embed (loop j 4 (move (*d 1d i) (/a 1a 4)))))
|
694 |
+
"""%n,
|
695 |
+
needToTrain=n == 5)
|
696 |
+
return tasks
|
697 |
+
|
698 |
+
def montageTasks(tasks, prefix="", columns=None, testTrain=False):
|
699 |
+
import numpy as np
|
700 |
+
|
701 |
+
w = 128
|
702 |
+
arrays = [t.highresolution for t in tasks]
|
703 |
+
for a in arrays:
|
704 |
+
assert len(a) == w*w
|
705 |
+
|
706 |
+
if testTrain:
|
707 |
+
arrays = [a for a,t in zip(arrays, tasks) if t.mustTrain ] + [a for a,t in zip(arrays, tasks) if not t.mustTrain ]
|
708 |
+
|
709 |
+
arrays = [np.array([a[i:i + w]
|
710 |
+
for i in range(0, len(a), w) ])
|
711 |
+
for a in arrays]
|
712 |
+
i = montage(arrays, columns=columns)
|
713 |
+
|
714 |
+
import scipy.misc
|
715 |
+
scipy.misc.imsave('/tmp/%smontage.png'%prefix, i)
|
716 |
+
if testTrain:
|
717 |
+
trainingTasks = arrays[:sum(t.mustTrain for t in tasks)]
|
718 |
+
testingTasks = arrays[sum(t.mustTrain for t in tasks):]
|
719 |
+
random.shuffle(trainingTasks)
|
720 |
+
random.shuffle(testingTasks)
|
721 |
+
arrays = trainingTasks + testingTasks
|
722 |
+
else:
|
723 |
+
random.shuffle(arrays)
|
724 |
+
scipy.misc.imsave('/tmp/%srandomMontage.png'%prefix, montage(arrays, columns=columns))
|
725 |
+
|
726 |
+
def demoLogoTasks():
|
727 |
+
import scipy.misc
|
728 |
+
import numpy as np
|
729 |
+
|
730 |
+
g0 = Grammar.uniform(primitives, continuationType=turtle)
|
731 |
+
eprint("dreaming into /tmp/dreams_0...")
|
732 |
+
N = 1000
|
733 |
+
programs = [ p
|
734 |
+
for _ in range(N)
|
735 |
+
for p in [g0.sample(arrow(turtle,turtle),
|
736 |
+
maximumDepth=20)]
|
737 |
+
if p is not None]
|
738 |
+
os.system("mkdir -p /tmp/dreams_0")
|
739 |
+
for n,p in enumerate(programs):
|
740 |
+
with open(f"/tmp/dreams_0/{n}.dream","w") as handle:
|
741 |
+
handle.write(str(p))
|
742 |
+
drawLogo(*programs, pretty=True, smoothPretty=False,
|
743 |
+
resolution=512,
|
744 |
+
filenames=[f"/tmp/dreams_0/{n}_pretty.png"
|
745 |
+
for n in range(len(programs)) ],
|
746 |
+
timeout=1)
|
747 |
+
|
748 |
+
if len(sys.argv) > 1:
|
749 |
+
tasks = makeTasks(sys.argv[1:],proto=False)
|
750 |
+
else:
|
751 |
+
tasks = makeTasks(['all'],proto=False)
|
752 |
+
montageTasks(tasks,columns=16,testTrain=True)
|
753 |
+
for n,t in enumerate(tasks):
|
754 |
+
a = t.highresolution
|
755 |
+
w = int(len(a)**0.5)
|
756 |
+
scipy.misc.imsave('/tmp/logo%d.png'%n, np.array([a[i:i+w]
|
757 |
+
for i in range(0,len(a),w) ]))
|
758 |
+
logo_safe_name = t.name.replace("=","_").replace(' ','_').replace('/','_').replace("-","_") + ".png"
|
759 |
+
#os.system(f"convert /tmp/logo{n}.png -morphology Dilate Octagon /tmp/{logo_safe_name}")
|
760 |
+
os.system(f"convert /tmp/logo{n}.png -channel RGB -negate /tmp/{logo_safe_name}")
|
761 |
+
eprint(len(tasks),"tasks")
|
762 |
+
eprint(sum(t.mustTrain for t in tasks),"need to be trained on")
|
763 |
+
|
764 |
+
for t in dSLDemo():
|
765 |
+
a = t.highresolution
|
766 |
+
w = int(len(a)**0.5)
|
767 |
+
scipy.misc.imsave('/tmp/logoDemo%s.png'%t.name, np.array([a[i:i+w]
|
768 |
+
for i in range(0,len(a),w) ]))
|
769 |
+
os.system(f"convert /tmp/logoDemo{t.name}.png -morphology Dilate Octagon /tmp/logoDemo{t.name}_dilated.png")
|
770 |
+
|
771 |
+
tasks = [t for t in tasks if t.mustTrain ]
|
772 |
+
random.shuffle(tasks)
|
773 |
+
montageTasks(tasks[:16*3],"subset",columns=16)
|
774 |
+
|
775 |
+
montageTasks(rotationalSymmetryDemo(),"rotational")
|
776 |
+
|
777 |
+
|
dreamcoder/domains/misc/RobustFillPrimitives.py
ADDED
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#RobustFillPrimitives
|
2 |
+
|
3 |
+
from dreamcoder.program import Primitive, prettyProgram
|
4 |
+
from dreamcoder.grammar import Grammar
|
5 |
+
from dreamcoder.type import tint, arrow, baseType #, t0, t1, t2
|
6 |
+
|
7 |
+
from string import printable
|
8 |
+
import re
|
9 |
+
from collections import defaultdict
|
10 |
+
|
11 |
+
#from functools import reduce
|
12 |
+
|
13 |
+
|
14 |
+
disallowed = [
|
15 |
+
("#", "hash"),
|
16 |
+
("!", "bang"),
|
17 |
+
("\"", "double_quote"),
|
18 |
+
("$", "dollar"),
|
19 |
+
("%", "percent"),
|
20 |
+
("&", "ampersand"),
|
21 |
+
("'", "single_quote"),
|
22 |
+
(")", "left_paren"),
|
23 |
+
("(", "right_paren"),
|
24 |
+
("*", "astrisk"),
|
25 |
+
("+", "plus"),
|
26 |
+
(",", "comma"),
|
27 |
+
("-", "dash"),
|
28 |
+
(".", "period"),
|
29 |
+
("/", "slash"),
|
30 |
+
(":", "colon"),
|
31 |
+
(";", "semicolon"),
|
32 |
+
("<", "less_than"),
|
33 |
+
("=", "equal"),
|
34 |
+
(">", "greater_than"),
|
35 |
+
("?", "question_mark"),
|
36 |
+
("@", "at"),
|
37 |
+
("[", "left_bracket"),
|
38 |
+
("\\", "backslash"),
|
39 |
+
("]", "right_bracket"),
|
40 |
+
("^", "carrot"),
|
41 |
+
("_", "underscore"),
|
42 |
+
("`", "backtick"),
|
43 |
+
("|", "bar"),
|
44 |
+
("}", "right_brace"),
|
45 |
+
("{", "left_brace"),
|
46 |
+
("~", "tilde"),
|
47 |
+
(" ", "space"),
|
48 |
+
("\t", "tab")
|
49 |
+
]
|
50 |
+
disallowed = dict(disallowed)
|
51 |
+
delimiters = "&,.?!@()[]%{/}:;$#\"'"
|
52 |
+
|
53 |
+
delim_dict = {disallowed[c]:c for c in delimiters}
|
54 |
+
|
55 |
+
types = {}
|
56 |
+
types["Number"] = r'\d+'
|
57 |
+
types["Word"] = r'\w+'
|
58 |
+
types["Alphanum"] = r'\w'
|
59 |
+
types["PropCase"] = r'[A-Z][a-z]+'
|
60 |
+
types["AllCaps"] = r'[A-Z]'
|
61 |
+
types["Lower"] = r'[a-z]'
|
62 |
+
types["Digit"] = r'\d'
|
63 |
+
types["Char"] = r'.'
|
64 |
+
|
65 |
+
regexes = {name: re.escape(val) for name, val in delim_dict.items()}
|
66 |
+
regexes = {**regexes, **types}
|
67 |
+
|
68 |
+
tposition = baseType("position")
|
69 |
+
tindex = baseType("index")
|
70 |
+
tcharacter = baseType("character")
|
71 |
+
tboundary = baseType("boundary")
|
72 |
+
tregex = baseType("regex")
|
73 |
+
tsubstr = baseType("substr")
|
74 |
+
texpression = baseType("expression")
|
75 |
+
tprogram = baseType("program")
|
76 |
+
tnesting = baseType("nesting")
|
77 |
+
ttype = baseType("type")
|
78 |
+
tdelimiter = baseType("delimiter")
|
79 |
+
|
80 |
+
def _substr(k1): return lambda k2: lambda string: string[k1:k2] #i think this is fine
|
81 |
+
def _getspan(r1):
|
82 |
+
return lambda i1: lambda b1: lambda r2: lambda i2: lambda b2: lambda string: \
|
83 |
+
string[
|
84 |
+
[m.end() for m in re.finditer(r1, string)][i1] if b1 == "End" else [m.start() for m in re.finditer(r1, string)][i1]:[m.end() for m in re.finditer(r2, string)][i2] if b2 == "End" else [m.start() for m in re.finditer(r2, string)][i2]
|
85 |
+
]
|
86 |
+
#TODO format correctly
|
87 |
+
def _getspan_const(r1): return lambda i1: lambda b1: lambda r2: lambda i2: lambda b2: (defaultdict(int, {r1:i1+1 if i1>=0 else abs(i1), r2:i2+1 if i2>=0 else abs(i2)}), max(i1+1 if i1>=0 else abs(i1), i2+1 if i2>=0 else abs(i2)))
|
88 |
+
|
89 |
+
|
90 |
+
def _trim(string):
|
91 |
+
assert False
|
92 |
+
return string
|
93 |
+
|
94 |
+
def _replace(d1, d2): return lambda string: string.replace(d1,d2)
|
95 |
+
|
96 |
+
def _getall(tp): return lambda string: ''.join(re.findall(tp, string))
|
97 |
+
def _getfirst(tp, i): return lambda string: ''.join(re.findall(tp, string)[:i])
|
98 |
+
def _gettoken(tp, i): return lambda string: re.findall(tp, string)[i]
|
99 |
+
def _gettoken_const(tp, i): return defaultdict(int, {tp: i+1 if i>=0 else abs(i)}), i+1 if i>=0 else abs(i)
|
100 |
+
|
101 |
+
def _getupto(reg): return lambda string: string[:[m.end() for m in re.finditer(reg, string)][0]]
|
102 |
+
def _getfrom(reg): return lambda string: string[[m.end() for m in re.finditer(reg, string)][-1]:]
|
103 |
+
|
104 |
+
def _concat2(expr1): return lambda expr2: lambda string: expr1(string) + expr2(string) #More concats plz
|
105 |
+
def _concat1(expr): return lambda string: expr(string)
|
106 |
+
def _concat_list(expr): return lambda program: lambda string: expr(string) + program(string)
|
107 |
+
#i've decided that all of the things which are expressions should take tstring as last input and output a tstring. Thus, all requests are arrow(tstring, tstring) and we limit size with recursive depth
|
108 |
+
"""
|
109 |
+
todo:
|
110 |
+
- _trim
|
111 |
+
- incorporate tcharacter
|
112 |
+
- constraints
|
113 |
+
- format _getspan
|
114 |
+
- figure out how to represent on top_level
|
115 |
+
|
116 |
+
- flatten for nn
|
117 |
+
- parse
|
118 |
+
|
119 |
+
- robustfill_util
|
120 |
+
- train dc model for robustfill
|
121 |
+
- main_supervised_robustfill
|
122 |
+
- evaluate_robustfill
|
123 |
+
- sample_data
|
124 |
+
|
125 |
+
|
126 |
+
- deal with escapes ...
|
127 |
+
|
128 |
+
constraints:
|
129 |
+
elements, and number necessary, and lengths
|
130 |
+
"""
|
131 |
+
|
132 |
+
def robustFillPrimitives(max_len=100, max_index=5):
|
133 |
+
return [
|
134 |
+
#CPrimitive("concat2", arrow(texpression, texpression, tprogram), _concat2),
|
135 |
+
CPrimitive("concat1", arrow(texpression, tprogram), _concat1),
|
136 |
+
CPrimitive("concat_list", arrow(texpression, tprogram, tprogram), _concat_list),
|
137 |
+
#expressions
|
138 |
+
CPrimitive("Constant", arrow(tcharacter, texpression), lambda x: lambda y: x), # add a constraint
|
139 |
+
CPrimitive("apply", arrow(tnesting, tsubstr, texpression), lambda n: lambda sub: lambda string: n(sub(string))),
|
140 |
+
CPrimitive("apply_n", arrow(tnesting, tnesting, texpression), lambda n1: lambda n2: lambda string: n1(n2(string))),
|
141 |
+
CPrimitive("expr_n", arrow(tnesting, texpression), lambda x: x),
|
142 |
+
CPrimitive("expr_f", arrow(tsubstr, texpression), lambda x: x)
|
143 |
+
] + [
|
144 |
+
#substrings
|
145 |
+
CPrimitive("SubStr", arrow(tposition, tposition, tsubstr), _substr), # handled
|
146 |
+
CPrimitive("GetSpan", arrow(tregex, tindex, tboundary, tregex, tindex, tboundary, tsubstr), _getspan, _getspan_const) #TODO constraint
|
147 |
+
] + [
|
148 |
+
#nestings
|
149 |
+
CPrimitive("GetToken"+name+str(i), tnesting, _gettoken(tp,i), _gettoken_const(tp, i)) for name, tp in types.items() for i in range(-max_index, max_index)
|
150 |
+
] + [
|
151 |
+
CPrimitive("ToCase_ProperCase", tnesting, lambda x: x.title(), (defaultdict(int, {r'[A-Z][a-z]+':1}), 1)),
|
152 |
+
CPrimitive("ToCase_AllCapsCase", tnesting, lambda x: x.upper(), (defaultdict(int, {r'[A-Z]':1}) ,1)),
|
153 |
+
CPrimitive("ToCase_LowerCase", tnesting, lambda x: x.lower(), (defaultdict(int, {r'[a-z]':1}), 1) )
|
154 |
+
] + [
|
155 |
+
CPrimitive("Replace_"+name1+name2, tnesting, _replace(char1, char2), (defaultdict(int, {char1:1}), 1)) for name1, char1 in delim_dict.items() for name2, char2 in delim_dict.items() if char1 is not char2
|
156 |
+
] + [
|
157 |
+
#CPrimitive("Trim", tnesting, _trim), #TODO
|
158 |
+
] + [
|
159 |
+
CPrimitive("GetUpTo"+name, tnesting, _getupto(reg), (defaultdict(int, {reg:1} ),1)) for name, reg in regexes.items()
|
160 |
+
] + [
|
161 |
+
CPrimitive("GetFrom"+name, tnesting, _getfrom(reg), (defaultdict(int, {reg:1} ),1)) for name, reg in regexes.items()
|
162 |
+
] + [
|
163 |
+
CPrimitive("GetFirst_"+name+str(i), tnesting, _getfirst(tp, i), (defaultdict(int, {tp:i} ), i+1 if i>=0 else abs(i))) for name, tp in types.items() for i in list(range(-max_index,0))+ list(range(1,max_index+1))
|
164 |
+
] + [
|
165 |
+
CPrimitive("GetAll_"+name, tnesting, _getall(reg),(defaultdict(int, {reg:1} ),1) ) for name, reg in types.items()
|
166 |
+
] + [
|
167 |
+
#regexes
|
168 |
+
CPrimitive("type_to_regex", arrow(ttype, tregex), lambda x: x), #TODO also make disappear
|
169 |
+
CPrimitive("delimiter_to_regex", arrow(tdelimiter, tregex), lambda x: re.escape(x)) #TODO also make disappear
|
170 |
+
] + [
|
171 |
+
#types
|
172 |
+
CPrimitive("Number", ttype, r'\d+', r'\d+'), #TODO
|
173 |
+
CPrimitive("Word", ttype, r'\w+', r'\w+'), #TODO
|
174 |
+
CPrimitive("Alphanum", ttype, r'\w', r'\w'), #TODO
|
175 |
+
CPrimitive("PropCase", ttype, r'[A-Z][a-z]+', r'[A-Z][a-z]+'), #TODO
|
176 |
+
CPrimitive("AllCaps", ttype, r'[A-Z]', r'[A-Z]'), #TODO
|
177 |
+
CPrimitive("Lower", ttype, r'[a-z]', r'[a-z]'), #TODO
|
178 |
+
CPrimitive("Digit", ttype, r'\d', r'\d'), #TODO
|
179 |
+
CPrimitive("Char", ttype, r'.', r'.') #TODO
|
180 |
+
] + [
|
181 |
+
#Cases
|
182 |
+
# CPrimitive("ProperCase", tcase, .title()), #TODO
|
183 |
+
# CPrimitive("AllCapsCase", tcase, .upper()), #TODO
|
184 |
+
# CPrimitive("LowerCase", tcase, .lower()) #TODO
|
185 |
+
] + [
|
186 |
+
#positions
|
187 |
+
CPrimitive("position"+str(i), tposition, i, (defaultdict(int), i+1 if i>=0 else abs(i)) ) for i in range(-max_len,max_len+1) #deal with indicies
|
188 |
+
] + [
|
189 |
+
#indices
|
190 |
+
CPrimitive("index"+str(i), tindex, i, i) for i in range(-max_index,max_index+1) #deal with indicies
|
191 |
+
] + [
|
192 |
+
#characters
|
193 |
+
CPrimitive(i, tcharacter, i, (defaultdict(int, {i:1}),1) ) for i in printable[:-5] if i not in disallowed
|
194 |
+
] + [
|
195 |
+
CPrimitive(name, tcharacter, char, (defaultdict(int, {char:1}), 1)) for char, name in disallowed.items() # NB: disallowed is reversed
|
196 |
+
] + [
|
197 |
+
#delimiters
|
198 |
+
CPrimitive("delim_"+name, tdelimiter, char, char) for name, char in delim_dict.items()
|
199 |
+
] + [
|
200 |
+
#boundaries
|
201 |
+
CPrimitive("End", tboundary, "End"),
|
202 |
+
CPrimitive("Start", tboundary, "Start")
|
203 |
+
]
|
204 |
+
|
205 |
+
|
206 |
+
|
207 |
+
def RobustFillProductions(max_len=100, max_index=5):
|
208 |
+
return [(0.0, prim) for prim in robustFillPrimitives(max_len=max_len, max_index=max_index)]
|
209 |
+
|
210 |
+
|
211 |
+
def flatten_program(p):
|
212 |
+
string = p.show(False)
|
213 |
+
string = string.replace('(', '')
|
214 |
+
string = string.replace(')', '')
|
215 |
+
#remove '_fn' (optional)
|
216 |
+
string = string.split(' ')
|
217 |
+
string = list(filter(lambda x: x is not '', string))
|
218 |
+
return string
|
219 |
+
|
220 |
+
|
221 |
+
|
222 |
+
|
223 |
+
def add_constraints(c1, c2=None):
|
224 |
+
if c2 is None:
|
225 |
+
return c1
|
226 |
+
d1, m1 = c1
|
227 |
+
d2, m2 = c2
|
228 |
+
min_size = max(m1, m2)
|
229 |
+
d = defaultdict(int)
|
230 |
+
for item in set(d1.keys()) | set(d2.keys()):
|
231 |
+
d[item] = max(d1[item], d2[item])
|
232 |
+
return d, min_size
|
233 |
+
|
234 |
+
# class Constraint_prop:
|
235 |
+
# def application(self, p, environment):
|
236 |
+
# self.f.visit(self, environment)(self.x.visit(self, environment))
|
237 |
+
# def primitive(self, p, environment):
|
238 |
+
# return self.value
|
239 |
+
|
240 |
+
class Constraint_prop:
|
241 |
+
def __init__(self):
|
242 |
+
pass
|
243 |
+
|
244 |
+
def application(self, p):
|
245 |
+
return p.f.visit(self)(p.x.visit(self))
|
246 |
+
|
247 |
+
def primitive(self, p):
|
248 |
+
return p.constraint
|
249 |
+
|
250 |
+
def execute(self, p):
|
251 |
+
return p.visit(self)
|
252 |
+
|
253 |
+
|
254 |
+
class CPrimitive(Primitive):
|
255 |
+
def __init__(self, name, ty, value, constraint=None):
|
256 |
+
#I have no idea why this works but it does .....
|
257 |
+
if constraint is None:
|
258 |
+
if len(ty.functionArguments())==0:
|
259 |
+
self.constraint = (defaultdict(int), 0)
|
260 |
+
elif len(ty.functionArguments())==1:
|
261 |
+
self.constraint = lambda x: x
|
262 |
+
elif len(ty.functionArguments())==2:
|
263 |
+
self.constraint = lambda x: lambda y: add_constraints(x,y)
|
264 |
+
else:
|
265 |
+
self.constraint = lambda x: x
|
266 |
+
for _ in range(len(ty.functionArguments()) - 1):
|
267 |
+
self.constraint = lambda x: lambda y: add_constraints(x, self.constraint(y))
|
268 |
+
else: self.constraint = constraint
|
269 |
+
super(CPrimitive, self).__init__(name, ty, value)
|
270 |
+
|
271 |
+
#def __getinitargs__(self):
|
272 |
+
# return (self.name, self.tp, self.value, None)
|
273 |
+
|
274 |
+
def __getstate__(self):
|
275 |
+
#print("self.name", self.name)
|
276 |
+
return self.name
|
277 |
+
|
278 |
+
def __setstate__(self, state):
|
279 |
+
#for backwards compatibility:
|
280 |
+
if type(state) == dict:
|
281 |
+
pass #do nothing, i don't need to load them if they are old...
|
282 |
+
else:
|
283 |
+
p = Primitive.GLOBALS[state]
|
284 |
+
self.__init__(p.name, p.tp, p.value, p.constraint)
|
285 |
+
|
286 |
+
|
287 |
+
|
288 |
+
if __name__=='__main__':
|
289 |
+
import time
|
290 |
+
CPrimitive("testCPrim", tint, lambda x: x, 17)
|
291 |
+
g = Grammar.fromProductions(RobustFillProductions())
|
292 |
+
print(len(g))
|
293 |
+
request = tprogram
|
294 |
+
p = g.sample(request)
|
295 |
+
print("request:", request)
|
296 |
+
print("program:")
|
297 |
+
print(prettyProgram(p))
|
298 |
+
s = 'abcdefg'
|
299 |
+
e = p.evaluate([])
|
300 |
+
#print("prog applied to", s)
|
301 |
+
#print(e(s))
|
302 |
+
print("flattened_program:")
|
303 |
+
flat = flatten_program(p)
|
304 |
+
print(flat)
|
305 |
+
t = time.time()
|
306 |
+
constraints = Constraint_prop().execute(p)
|
307 |
+
print(time.time() - t)
|
308 |
+
print(constraints)
|
dreamcoder/domains/misc/__init__.py
ADDED
File without changes
|
dreamcoder/domains/misc/algolispPrimitives.py
ADDED
@@ -0,0 +1,508 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#napsPrimitives.py
|
2 |
+
from dreamcoder.program import Primitive, Program
|
3 |
+
from dreamcoder.grammar import Grammar
|
4 |
+
from dreamcoder.type import tlist, arrow, baseType #, t0, t1, t2
|
5 |
+
|
6 |
+
#from functools import reduce
|
7 |
+
|
8 |
+
|
9 |
+
#Internal TYPES:
|
10 |
+
# NUMBER
|
11 |
+
# BOOLEAN
|
12 |
+
# NOTFUNCTYPE
|
13 |
+
# Type
|
14 |
+
# ANYTYPE
|
15 |
+
|
16 |
+
#types
|
17 |
+
tsymbol = baseType("symbol")
|
18 |
+
#PROGRAM = SYMBOL = constant | argument | function_call | function | lambda
|
19 |
+
tconstant = baseType("constant")
|
20 |
+
tfunction = baseType("function")
|
21 |
+
|
22 |
+
f = dict([("|||","triple_or"),
|
23 |
+
("reduce","reduce"),
|
24 |
+
("+","+"),
|
25 |
+
("len","len"),
|
26 |
+
("map","map"),
|
27 |
+
("filter","filter"),
|
28 |
+
("-","-"),
|
29 |
+
("*","*"),
|
30 |
+
("partial0","partial0"),
|
31 |
+
("if","if"),
|
32 |
+
("lambda1","lambda1"),
|
33 |
+
("==","eq"),
|
34 |
+
("range","range"),
|
35 |
+
("digits","digits"),
|
36 |
+
("slice","slice"),
|
37 |
+
("reverse","reverse"),
|
38 |
+
("lambda2","lambda2"),
|
39 |
+
("deref","deref"),
|
40 |
+
("partial1","partial1"),
|
41 |
+
("/","div"),
|
42 |
+
("<","less_than"),
|
43 |
+
(">","greater_than"),
|
44 |
+
("min","min"),
|
45 |
+
("combine","combine"),
|
46 |
+
("head","head"),
|
47 |
+
("is_prime","is_prime"),
|
48 |
+
("false","false"),
|
49 |
+
("||","or"),
|
50 |
+
("10","10"),
|
51 |
+
("self","self"),
|
52 |
+
("max","max"),
|
53 |
+
("sort","sort"),
|
54 |
+
("%","mod"),
|
55 |
+
("invoke1","invoke1"),
|
56 |
+
("!","bang"),
|
57 |
+
("square","square"),
|
58 |
+
("str_concat","str_concat"),
|
59 |
+
("strlen","strlen"),
|
60 |
+
("<=","leq"),
|
61 |
+
("int-deref","int-deref"),
|
62 |
+
("str_split","str_split"),
|
63 |
+
("str_index","str_index"),
|
64 |
+
("floor","floor"),
|
65 |
+
("sqrt","sqrt"),
|
66 |
+
("str_min","str_min"),
|
67 |
+
("&&","AND"),
|
68 |
+
("is_sorted","is_sorted"),
|
69 |
+
("str_max","str_max"),
|
70 |
+
(">=","geq")])
|
71 |
+
|
72 |
+
fn_lookup = {
|
73 |
+
**f
|
74 |
+
}
|
75 |
+
|
76 |
+
c = dict(
|
77 |
+
[("0","0"),
|
78 |
+
("a","a"),
|
79 |
+
("arg1","arg1"),
|
80 |
+
("1","1"),
|
81 |
+
("b","b"),
|
82 |
+
("2","2"),
|
83 |
+
("c","c"),
|
84 |
+
("arg2","arg2"),
|
85 |
+
("d","d"),
|
86 |
+
("false","false"),
|
87 |
+
("10","10"),
|
88 |
+
("self","self"),
|
89 |
+
("1000000000","1000000000"),
|
90 |
+
("\"\"", "empty_str"),
|
91 |
+
("e","e"),
|
92 |
+
("40","40"),
|
93 |
+
("f","f"),
|
94 |
+
("\" \"", "space"),
|
95 |
+
("g","g"),
|
96 |
+
("\"z\"","z"),
|
97 |
+
("true","true"),
|
98 |
+
("h","h"),
|
99 |
+
("i","i"),
|
100 |
+
("j","j"),
|
101 |
+
("k","k"),
|
102 |
+
("l","l")]
|
103 |
+
)
|
104 |
+
|
105 |
+
const_lookup = {
|
106 |
+
**c
|
107 |
+
}
|
108 |
+
|
109 |
+
primitive_lookup = {**const_lookup, **fn_lookup}
|
110 |
+
#Do i need arguments??
|
111 |
+
|
112 |
+
def _fn_call(f):
|
113 |
+
#print("f", f)
|
114 |
+
def inner(sx):
|
115 |
+
#print("sx", sx)
|
116 |
+
if not type(sx) == list:
|
117 |
+
sx = [sx]
|
118 |
+
return [f] + sx
|
119 |
+
return lambda sx: inner(sx)
|
120 |
+
|
121 |
+
def algolispPrimitives():
|
122 |
+
return [
|
123 |
+
Primitive("fn_call", arrow(tfunction, tlist(tsymbol), tsymbol), _fn_call),
|
124 |
+
|
125 |
+
Primitive("lambda1_call", arrow(tfunction, tlist(tsymbol), tsymbol), lambda f: lambda sx: ["lambda1", [f] + sx] if type(sx)==list else ["lambda1", [f] + [sx]] ),
|
126 |
+
Primitive("lambda2_call", arrow(tfunction, tlist(tsymbol), tsymbol), lambda f: lambda sx: ["lambda2", [f] + sx] if type(sx)==list else ["lambda2", [f] + [sx]] ),
|
127 |
+
#symbol converters:
|
128 |
+
# SYMBOL = constant | argument | function_call | function | lambda
|
129 |
+
Primitive("symbol_constant", arrow(tconstant, tsymbol), lambda x: x),
|
130 |
+
Primitive("symbol_function", arrow(tfunction, tsymbol), lambda x: x),
|
131 |
+
#list converters
|
132 |
+
Primitive('list_init_symbol', arrow(tsymbol, tlist(tsymbol)), lambda symbol: [symbol] ),
|
133 |
+
Primitive('list_add_symbol', arrow(tsymbol, tlist(tsymbol), tlist(tsymbol)), lambda symbol: lambda symbols: symbols + [symbol] if type(symbols) == list else [symbols] + [symbol])
|
134 |
+
] + [
|
135 |
+
#functions:
|
136 |
+
Primitive(ec_name, tfunction, algo_name) for algo_name, ec_name in fn_lookup.items()
|
137 |
+
] + [
|
138 |
+
#Constants
|
139 |
+
Primitive(ec_name, tconstant, algo_name) for algo_name, ec_name in const_lookup.items()
|
140 |
+
]
|
141 |
+
|
142 |
+
|
143 |
+
#for first pass, can just hard code vars and maps n stuff
|
144 |
+
|
145 |
+
|
146 |
+
def algolispProductions():
|
147 |
+
return [(0.0, prim) for prim in algolispPrimitives()]
|
148 |
+
|
149 |
+
|
150 |
+
algolisp_input_vocab = [
|
151 |
+
"<S>",
|
152 |
+
"</S>",
|
153 |
+
"<UNK>",
|
154 |
+
"|||",
|
155 |
+
"(",
|
156 |
+
")",
|
157 |
+
"a",
|
158 |
+
"b",
|
159 |
+
"of",
|
160 |
+
"the",
|
161 |
+
"0",
|
162 |
+
",",
|
163 |
+
"arg1",
|
164 |
+
"c",
|
165 |
+
"and",
|
166 |
+
"1",
|
167 |
+
"reduce",
|
168 |
+
"+",
|
169 |
+
"int[]",
|
170 |
+
"in",
|
171 |
+
"given",
|
172 |
+
"numbers",
|
173 |
+
"int",
|
174 |
+
"is",
|
175 |
+
"len",
|
176 |
+
"map",
|
177 |
+
"digits",
|
178 |
+
"d",
|
179 |
+
"number",
|
180 |
+
"array",
|
181 |
+
"-",
|
182 |
+
"filter",
|
183 |
+
"to",
|
184 |
+
"range",
|
185 |
+
"are",
|
186 |
+
"*",
|
187 |
+
"partial0",
|
188 |
+
"2",
|
189 |
+
"if",
|
190 |
+
"reverse",
|
191 |
+
"that",
|
192 |
+
"elements",
|
193 |
+
"lambda1",
|
194 |
+
"==",
|
195 |
+
"an",
|
196 |
+
"arg2",
|
197 |
+
"values",
|
198 |
+
"slice",
|
199 |
+
"element",
|
200 |
+
"lambda2",
|
201 |
+
"deref",
|
202 |
+
"you",
|
203 |
+
"partial1",
|
204 |
+
"e",
|
205 |
+
"find",
|
206 |
+
"your",
|
207 |
+
"task",
|
208 |
+
"compute",
|
209 |
+
"among",
|
210 |
+
"from",
|
211 |
+
"consider",
|
212 |
+
"first",
|
213 |
+
"than",
|
214 |
+
"value",
|
215 |
+
"/",
|
216 |
+
"what",
|
217 |
+
"arrays",
|
218 |
+
"with",
|
219 |
+
"<",
|
220 |
+
"length",
|
221 |
+
">",
|
222 |
+
"be",
|
223 |
+
"min",
|
224 |
+
"end",
|
225 |
+
"sum",
|
226 |
+
"one",
|
227 |
+
"head",
|
228 |
+
"f",
|
229 |
+
"by",
|
230 |
+
"combine",
|
231 |
+
"segment",
|
232 |
+
"coordinates",
|
233 |
+
"not",
|
234 |
+
"string",
|
235 |
+
"is_prime",
|
236 |
+
"false",
|
237 |
+
"||",
|
238 |
+
"at",
|
239 |
+
"10",
|
240 |
+
"half",
|
241 |
+
"position",
|
242 |
+
"self",
|
243 |
+
"subsequence",
|
244 |
+
"after",
|
245 |
+
"such",
|
246 |
+
"max",
|
247 |
+
"prime",
|
248 |
+
"sort",
|
249 |
+
"let",
|
250 |
+
"%",
|
251 |
+
"longest",
|
252 |
+
"inclusive",
|
253 |
+
"which",
|
254 |
+
"invoke1",
|
255 |
+
"1000000000",
|
256 |
+
"all",
|
257 |
+
"positions",
|
258 |
+
"!",
|
259 |
+
"square",
|
260 |
+
"its",
|
261 |
+
"has",
|
262 |
+
"reversed",
|
263 |
+
"another",
|
264 |
+
"less",
|
265 |
+
"each",
|
266 |
+
"\"\"",
|
267 |
+
"order",
|
268 |
+
"largest",
|
269 |
+
"maximum",
|
270 |
+
"g",
|
271 |
+
"last",
|
272 |
+
"smallest",
|
273 |
+
"times",
|
274 |
+
"strictly",
|
275 |
+
"40",
|
276 |
+
"smaller",
|
277 |
+
"indexes",
|
278 |
+
"str_concat",
|
279 |
+
"strlen",
|
280 |
+
"two",
|
281 |
+
"starting",
|
282 |
+
"<=",
|
283 |
+
"on",
|
284 |
+
"greater",
|
285 |
+
"how",
|
286 |
+
"many",
|
287 |
+
"int-deref",
|
288 |
+
"prefix",
|
289 |
+
"bigger",
|
290 |
+
"only",
|
291 |
+
"str_split",
|
292 |
+
"\" \"",
|
293 |
+
"str_index",
|
294 |
+
"can",
|
295 |
+
"plus",
|
296 |
+
"squared",
|
297 |
+
"product",
|
298 |
+
"strings",
|
299 |
+
"floor",
|
300 |
+
"sqrt",
|
301 |
+
"before",
|
302 |
+
"it",
|
303 |
+
"concatenation",
|
304 |
+
"index",
|
305 |
+
"as",
|
306 |
+
"define",
|
307 |
+
"multiplied",
|
308 |
+
"biggest",
|
309 |
+
"rounded",
|
310 |
+
"down",
|
311 |
+
"string[]",
|
312 |
+
"equal",
|
313 |
+
"integer",
|
314 |
+
"also",
|
315 |
+
"based",
|
316 |
+
"sorting",
|
317 |
+
"replace",
|
318 |
+
"becomes",
|
319 |
+
"single",
|
320 |
+
"digit",
|
321 |
+
"characters",
|
322 |
+
"keeping",
|
323 |
+
"including",
|
324 |
+
"h",
|
325 |
+
"larger",
|
326 |
+
"written",
|
327 |
+
"divisible",
|
328 |
+
"previous",
|
329 |
+
"subarray",
|
330 |
+
"mininum",
|
331 |
+
"second",
|
332 |
+
"middle",
|
333 |
+
"same",
|
334 |
+
"th",
|
335 |
+
"median",
|
336 |
+
"till",
|
337 |
+
"integers",
|
338 |
+
"sequence",
|
339 |
+
"for",
|
340 |
+
"indices",
|
341 |
+
"between",
|
342 |
+
"when",
|
343 |
+
"doubled",
|
344 |
+
"ending",
|
345 |
+
"even",
|
346 |
+
"multiply",
|
347 |
+
"squares",
|
348 |
+
"fibonacci",
|
349 |
+
"exclusive",
|
350 |
+
"odd",
|
351 |
+
"keep",
|
352 |
+
"whether",
|
353 |
+
"minimum",
|
354 |
+
"except",
|
355 |
+
"letters",
|
356 |
+
"appearing",
|
357 |
+
"letter",
|
358 |
+
"consecutive",
|
359 |
+
"character",
|
360 |
+
"factorial",
|
361 |
+
"chosen",
|
362 |
+
"start",
|
363 |
+
"begin",
|
364 |
+
"themselves",
|
365 |
+
"\"z\"",
|
366 |
+
"str_min",
|
367 |
+
"remove",
|
368 |
+
"present",
|
369 |
+
"exist",
|
370 |
+
"appear",
|
371 |
+
"starts",
|
372 |
+
"i",
|
373 |
+
"located",
|
374 |
+
"true",
|
375 |
+
"&&",
|
376 |
+
"found",
|
377 |
+
"discarding",
|
378 |
+
"is_sorted",
|
379 |
+
"removing",
|
380 |
+
"do",
|
381 |
+
"increasing",
|
382 |
+
"exceed",
|
383 |
+
"ascending",
|
384 |
+
"difference",
|
385 |
+
"decremented",
|
386 |
+
"existing",
|
387 |
+
"alphabetically",
|
388 |
+
"words",
|
389 |
+
"added",
|
390 |
+
"incremented",
|
391 |
+
"backwards",
|
392 |
+
"individual",
|
393 |
+
"lexicographically",
|
394 |
+
"separate",
|
395 |
+
"abbreviation",
|
396 |
+
"str_max",
|
397 |
+
"increment",
|
398 |
+
"consisting",
|
399 |
+
"equals",
|
400 |
+
"having",
|
401 |
+
"discard",
|
402 |
+
"descending",
|
403 |
+
"decreasing",
|
404 |
+
"sorted",
|
405 |
+
"being",
|
406 |
+
"where",
|
407 |
+
"right",
|
408 |
+
"there",
|
409 |
+
"ordinal",
|
410 |
+
"have",
|
411 |
+
"s",
|
412 |
+
"going",
|
413 |
+
"'",
|
414 |
+
"add",
|
415 |
+
"space",
|
416 |
+
"decrement",
|
417 |
+
"those",
|
418 |
+
"whitespaces",
|
419 |
+
"spaces",
|
420 |
+
"subtract",
|
421 |
+
"remaining",
|
422 |
+
"following",
|
423 |
+
"or",
|
424 |
+
"out",
|
425 |
+
"ordered",
|
426 |
+
"minimal",
|
427 |
+
"itself",
|
428 |
+
"symmetric",
|
429 |
+
"read",
|
430 |
+
"increases",
|
431 |
+
"word",
|
432 |
+
"immidiately",
|
433 |
+
"excluding",
|
434 |
+
"j",
|
435 |
+
"omitting",
|
436 |
+
"reads",
|
437 |
+
"maximal",
|
438 |
+
">=",
|
439 |
+
"compare",
|
440 |
+
"form",
|
441 |
+
"absent",
|
442 |
+
"missing",
|
443 |
+
"cannot",
|
444 |
+
"whose",
|
445 |
+
"count",
|
446 |
+
"lowest",
|
447 |
+
"both",
|
448 |
+
"ends",
|
449 |
+
"beginning",
|
450 |
+
"left",
|
451 |
+
"mean",
|
452 |
+
"average",
|
453 |
+
"obtained",
|
454 |
+
"writing",
|
455 |
+
"result",
|
456 |
+
"joining",
|
457 |
+
"together",
|
458 |
+
"increase",
|
459 |
+
"highest",
|
460 |
+
"comparing",
|
461 |
+
"forms",
|
462 |
+
"avg",
|
463 |
+
"outside",
|
464 |
+
"positive",
|
465 |
+
"summed",
|
466 |
+
"belonging",
|
467 |
+
"lexicographical",
|
468 |
+
"rest",
|
469 |
+
"belong",
|
470 |
+
"inclucing",
|
471 |
+
"lexical",
|
472 |
+
"alphabetical",
|
473 |
+
"dictionary",
|
474 |
+
"k",
|
475 |
+
"negative",
|
476 |
+
"lexicographic",
|
477 |
+
"represents",
|
478 |
+
"delete",
|
479 |
+
"non",
|
480 |
+
"l",
|
481 |
+
"erase",
|
482 |
+
"m",
|
483 |
+
"comes",
|
484 |
+
"up",
|
485 |
+
"comparison",
|
486 |
+
"during",
|
487 |
+
"'s value is the largest inclusive, which is strictly less than maximum element in numbers from 1 to the element in `a` which'",
|
488 |
+
"'s value is the biggest (inclusive), which is strictly less than maximum element of range from 1 to the element in `a` which'",
|
489 |
+
"'s value is the highest, which is strictly less than maximum element among sequence of digits of the element in `a` which'"]
|
490 |
+
|
491 |
+
|
492 |
+
if __name__ == "__main__":
|
493 |
+
#g = Grammar.uniform(deepcoderPrimitives())
|
494 |
+
|
495 |
+
g = Grammar.fromProductions(algolispProductions(), logVariable=.9)
|
496 |
+
|
497 |
+
#p=Program.parse("(lambda (fn_call filter (list_add_symbol (lambda1_call == (list_add_symbol 1 (list_init_symbol (fn_call mod ( list_add_symbol 2 (list_init_symbol arg1)) ))) ) (list_init_symbol $0)) )")
|
498 |
+
p=Program.parse("(lambda (fn_call filter (list_add_symbol (lambda1_call eq (list_add_symbol (symbol_constant 1) (list_init_symbol (fn_call mod ( list_add_symbol (symbol_constant 2) (list_init_symbol (symbol_constant arg1))) ))) ) (list_init_symbol (symbol_constant $0)))))")
|
499 |
+
|
500 |
+
print(p)
|
501 |
+
|
502 |
+
#tree = p.evaluate(["a"])
|
503 |
+
tree = p.evaluate([])
|
504 |
+
print(tree("a"))
|
505 |
+
|
506 |
+
#
|
507 |
+
|
508 |
+
|
dreamcoder/domains/misc/deepcoderPrimitives.py
ADDED
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dreamcoder.program import Primitive, prettyProgram
|
2 |
+
from dreamcoder.grammar import Grammar
|
3 |
+
from dreamcoder.type import tlist, tint, arrow, baseType #, t0, t1, t2
|
4 |
+
|
5 |
+
#from functools import reduce
|
6 |
+
|
7 |
+
|
8 |
+
#todo
|
9 |
+
int_to_int = baseType("int_to_int")
|
10 |
+
int_to_bool = baseType("int_to_bool")
|
11 |
+
int_to_int_to_int = baseType("int_to_int_to_int")
|
12 |
+
|
13 |
+
|
14 |
+
#deepcoderPrimitives
|
15 |
+
Null = 300 #or perhaps something else, like "an integer outside the working range"?
|
16 |
+
|
17 |
+
def _head(xs): return xs[0] if len(xs)>0 else Null
|
18 |
+
def _tail(xs): return xs[-1] if len(xs)>0 else Null
|
19 |
+
def _take(n): return lambda xs: xs[:n]
|
20 |
+
def _drop(n): return lambda xs: xs[n:]
|
21 |
+
def _access(n): return lambda xs: xs[n] if n>=0 and len(xs)>n else Null
|
22 |
+
def _minimum(xs): return min(xs) if len(xs)>0 else Null
|
23 |
+
def _maximum(xs): return max(xs) if len(xs)>0 else Null
|
24 |
+
def _reverse(xs): return list(reversed(xs))
|
25 |
+
def _sort(xs): return sorted(xs)
|
26 |
+
def _sum(xs): return sum(xs)
|
27 |
+
|
28 |
+
#higher order:
|
29 |
+
def _map(f): return lambda l: list(map(f, l))
|
30 |
+
def _filter(f): return lambda l: list(filter(f, l))
|
31 |
+
def _count(f): return lambda l: len([x for x in l if f(x)])
|
32 |
+
def _zipwith(f): return lambda xs: lambda ys: [f(x)(y) for (x, y) in zip(xs, ys)]
|
33 |
+
def _scanl1(f):
|
34 |
+
def _inner(xs):
|
35 |
+
ys = []
|
36 |
+
if len(xs) > 0:
|
37 |
+
ys.append(xs[0])
|
38 |
+
for i in range(1, len(xs)):
|
39 |
+
ys.append( f(ys[i-1])(xs[i]))
|
40 |
+
return ys
|
41 |
+
return _inner
|
42 |
+
|
43 |
+
#int to int:
|
44 |
+
def _succ(x): return x+1
|
45 |
+
def _pred(x): return x-1
|
46 |
+
def _double(x): return x*2
|
47 |
+
def _half(x): return int(x/2)
|
48 |
+
def _negate(x): return -x
|
49 |
+
def _square(x): return x**2
|
50 |
+
def _triple(x): return x*3
|
51 |
+
def _third(x): return int(x/3)
|
52 |
+
def _quad(x): return x*4
|
53 |
+
def _quarter(x): return int(x/4)
|
54 |
+
|
55 |
+
#int to bool:
|
56 |
+
def _pos(x): return x>0
|
57 |
+
def _neg(x): return x<0
|
58 |
+
def _even(x): return x%2==0
|
59 |
+
def _odd(x): return x%2==1
|
60 |
+
|
61 |
+
#int to int to int:
|
62 |
+
def _add(x): return lambda y: x + y
|
63 |
+
def _sub(x): return lambda y: x - y
|
64 |
+
def _mult(x): return lambda y: x * y
|
65 |
+
def _min(x): return lambda y: _minimum([x,y])
|
66 |
+
def _max(x): return lambda y: _maximum([x,y])
|
67 |
+
|
68 |
+
def deepcoderPrimitives():
|
69 |
+
return [
|
70 |
+
Primitive("HEAD", arrow(tlist(tint), tint), _head),
|
71 |
+
Primitive("LAST", arrow(tlist(tint), tint), _tail),
|
72 |
+
Primitive("TAKE", arrow(tint, tlist(tint), tlist(tint)), _take),
|
73 |
+
Primitive("DROP", arrow(tint, tlist(tint), tlist(tint)), _drop),
|
74 |
+
Primitive("ACCESS", arrow(tint, tlist(tint), tint), _access),
|
75 |
+
Primitive("MINIMUM", arrow(tlist(tint), tint), _minimum),
|
76 |
+
Primitive("MAXIMUM", arrow(tlist(tint), tint), _maximum),
|
77 |
+
Primitive("REVERSE", arrow(tlist(tint), tlist(tint)), _reverse),
|
78 |
+
Primitive("SORT", arrow(tlist(tint), tlist(tint)), _sort),
|
79 |
+
Primitive("SUM", arrow(tlist(tint), tint), _sum)
|
80 |
+
] + [
|
81 |
+
Primitive("MAP", arrow(int_to_int, tlist(tint), tlist(tint)), _map), #is this okay???
|
82 |
+
Primitive("FILTER", arrow(int_to_bool, tlist(tint), tlist(tint)), _filter), #is this okay???
|
83 |
+
Primitive("COUNT", arrow(int_to_bool, tlist(tint), tint), _count), #is this okay???
|
84 |
+
Primitive("ZIPWITH", arrow(int_to_int_to_int, tlist(tint), tlist(tint), tlist(tint)), _zipwith), #is this okay???
|
85 |
+
Primitive("SCANL1", arrow(int_to_int_to_int, tlist(tint), tlist(tint)), _scanl1), #is this okay???
|
86 |
+
] + [
|
87 |
+
Primitive("INC", int_to_int, _succ),
|
88 |
+
Primitive("DEC", int_to_int, _pred),
|
89 |
+
Primitive("SHL", int_to_int, _double),
|
90 |
+
Primitive("SHR", int_to_int, _half),
|
91 |
+
Primitive("doNEG", int_to_int, _negate),
|
92 |
+
Primitive("SQR", int_to_int, _square),
|
93 |
+
Primitive("MUL3", int_to_int, _triple),
|
94 |
+
Primitive("DIV3", int_to_int, _third),
|
95 |
+
Primitive("MUL4", int_to_int, _quad),
|
96 |
+
Primitive("DIV4", int_to_int, _quarter),
|
97 |
+
] + [
|
98 |
+
Primitive("isPOS", int_to_bool, _pos),
|
99 |
+
Primitive("isNEG", int_to_bool, _neg),
|
100 |
+
Primitive("isEVEN", int_to_bool, _even),
|
101 |
+
Primitive("isODD", int_to_bool, _odd),
|
102 |
+
] + [
|
103 |
+
Primitive("+", int_to_int_to_int, _add),
|
104 |
+
Primitive("-", int_to_int_to_int, _sub),
|
105 |
+
Primitive("*", int_to_int_to_int, _mult),
|
106 |
+
Primitive("MIN", int_to_int_to_int, _min),
|
107 |
+
Primitive("MAX", int_to_int_to_int, _max)
|
108 |
+
]
|
109 |
+
|
110 |
+
def OldDeepcoderPrimitives():
|
111 |
+
return [
|
112 |
+
Primitive("head", arrow(tlist(tint), tint), _head),
|
113 |
+
Primitive("tail", arrow(tlist(tint), tint), _tail),
|
114 |
+
Primitive("take", arrow(tint, tlist(tint), tlist(tint)), _take),
|
115 |
+
Primitive("drop", arrow(tint, tlist(tint), tlist(tint)), _drop),
|
116 |
+
Primitive("access", arrow(tint, tlist(tint), tint), _access),
|
117 |
+
Primitive("minimum", arrow(tlist(tint), tint), _minimum),
|
118 |
+
Primitive("maximum", arrow(tlist(tint), tint), _maximum),
|
119 |
+
Primitive("reverse", arrow(tlist(tint), tlist(tint)), _reverse),
|
120 |
+
Primitive("sort", arrow(tlist(tint), tlist(tint)), _sort),
|
121 |
+
Primitive("sum", arrow(tlist(tint), tint), _sum)
|
122 |
+
] + [
|
123 |
+
Primitive("map", arrow(int_to_int, tlist(tint), tlist(tint)), _map), #is this okay???
|
124 |
+
Primitive("filter_int", arrow(int_to_bool, tlist(tint), tlist(tint)), _filter), #is this okay???
|
125 |
+
Primitive("count", arrow(int_to_bool, tlist(tint), tint), _count), #is this okay???
|
126 |
+
Primitive("zipwith", arrow(int_to_int_to_int, tlist(tint), tlist(tint), tlist(tint)), _zipwith), #is this okay???
|
127 |
+
Primitive("scanl1", arrow(int_to_int_to_int, tlist(tint), tlist(tint)), _scanl1), #is this okay???
|
128 |
+
# ] + [
|
129 |
+
# Primitive("succ", arrow(tint, tint), _succ),
|
130 |
+
# Primitive("pred", arrow(tint, tint), _pred),
|
131 |
+
# Primitive("double", arrow(tint, tint), _double),
|
132 |
+
# Primitive("half", arrow(tint, tint), _half),
|
133 |
+
# Primitive("neg", arrow(tint, tint), _neg),
|
134 |
+
# Primitive("square", arrow(tint, tint), _square),
|
135 |
+
# Primitive("triple", arrow(tint, tint), _triple),
|
136 |
+
# Primitive("third", arrow(tint, tint), _third),
|
137 |
+
# Primitive("quad", arrow(tint, tint), _quad),
|
138 |
+
# Primitive("quarter", arrow(tint, tint), _quarter),
|
139 |
+
# ] + [
|
140 |
+
# Primitive("pos", arrow(tint, tbool), _pos),
|
141 |
+
# Primitive("neg", arrow(tint, tbool), _neg),
|
142 |
+
# Primitive("even", arrow(tint, tbool), _even),
|
143 |
+
# Primitive("odd", arrow(tint, tbool), _odd),
|
144 |
+
# ] + [
|
145 |
+
# Primitive("add", arrow(tint, tint, tint), _add),
|
146 |
+
# Primitive("sub", arrow(tint, tint, tint), _sub),
|
147 |
+
# Primitive("mult", arrow(tint, tint, tint), _mult),
|
148 |
+
# Primitive("min", arrow(tint, tint, tint), _min),
|
149 |
+
# Primitive("max", arrow(tint, tint, tint), _max)
|
150 |
+
] + [
|
151 |
+
Primitive("succ_fn", int_to_int, _succ),
|
152 |
+
Primitive("pred_fn", int_to_int, _pred),
|
153 |
+
Primitive("double_fn", int_to_int, _double),
|
154 |
+
Primitive("half_fn", int_to_int, _half),
|
155 |
+
Primitive("negate_fn", int_to_int, _negate),
|
156 |
+
Primitive("square_fn", int_to_int, _square),
|
157 |
+
Primitive("triple_fn", int_to_int, _triple),
|
158 |
+
Primitive("third_fn", int_to_int, _third),
|
159 |
+
Primitive("quad_fn", int_to_int, _quad),
|
160 |
+
Primitive("quarter_fn", int_to_int, _quarter),
|
161 |
+
] + [
|
162 |
+
Primitive("pos_fn", int_to_bool, _pos),
|
163 |
+
Primitive("neg_fn", int_to_bool, _neg),
|
164 |
+
Primitive("even_fn", int_to_bool, _even),
|
165 |
+
Primitive("odd_fn", int_to_bool, _odd),
|
166 |
+
] + [
|
167 |
+
Primitive("add_fn", int_to_int_to_int, _add),
|
168 |
+
Primitive("sub_fn", int_to_int_to_int, _sub),
|
169 |
+
Primitive("mult_fn", int_to_int_to_int, _mult),
|
170 |
+
Primitive("min_fn", int_to_int_to_int, _min),
|
171 |
+
Primitive("max_fn", int_to_int_to_int, _max)
|
172 |
+
]
|
173 |
+
|
174 |
+
def deepcoderProductions():
|
175 |
+
return [(0.0, prim) for prim in deepcoderPrimitives()]
|
176 |
+
|
177 |
+
def flatten_program(p):
|
178 |
+
string = p.show(False)
|
179 |
+
num_inputs = string.count('lambda')
|
180 |
+
string = string.replace('lambda', '')
|
181 |
+
string = string.replace('(', '')
|
182 |
+
string = string.replace(')', '')
|
183 |
+
#remove '_fn' (optional)
|
184 |
+
for i in range(num_inputs):
|
185 |
+
string = string.replace('$' + str(num_inputs-i-1),'input_' + str(i))
|
186 |
+
string = string.split(' ')
|
187 |
+
string = list(filter(lambda x: x is not '', string))
|
188 |
+
return string
|
189 |
+
|
190 |
+
if __name__ == "__main__":
|
191 |
+
#g = Grammar.uniform(deepcoderPrimitives())
|
192 |
+
g = Grammar.fromProductions(deepcoderProductions(), logVariable=.9)
|
193 |
+
request = arrow(tlist(tint), tint, tint)
|
194 |
+
p = g.sample(request)
|
195 |
+
print("request:", request)
|
196 |
+
print("program:")
|
197 |
+
print(prettyProgram(p))
|
198 |
+
print("flattened_program:")
|
199 |
+
flat = flatten_program(p)
|
200 |
+
print(flat)
|
201 |
+
|
202 |
+
#robustfill output = names from productions + input_0-2 or 3
|
203 |
+
|
204 |
+
|
205 |
+
|
206 |
+
|
207 |
+
|
208 |
+
# # with open("/home/ellisk/om/ec/experimentOutputs/list_aic=1.0_arity=3_ET=1800_expandFrontier=2.0_it=4_likelihoodModel=all-or-nothing_MF=5_baseline=False_pc=10.0_L=1.0_K=5_rec=False.pickle", "rb") as handle:
|
209 |
+
# # b = pickle.load(handle).grammars[-1]
|
210 |
+
# # print b
|
211 |
+
|
212 |
+
# p = Program.parse(
|
213 |
+
# "(lambda (lambda (lambda (if (empty? $0) empty (cons (+ (car $1) (car $0)) ($2 (cdr $1) (cdr $0)))))))")
|
214 |
+
# t = arrow(tlist(tint), tlist(tint), tlist(tint)) # ,tlist(tbool))
|
215 |
+
# print(g.logLikelihood(arrow(t, t), p))
|
216 |
+
# assert False
|
217 |
+
# print(b.logLikelihood(arrow(t, t), p))
|
218 |
+
|
219 |
+
# # p = Program.parse("""(lambda (lambda
|
220 |
+
# # (unfold 0
|
221 |
+
# # (lambda (+ (index $0 $2) (index $0 $1)))
|
222 |
+
# # (lambda (1+ $0))
|
223 |
+
# # (lambda (eq? $0 (length $1))))))
|
224 |
+
# # """)
|
225 |
+
# p = Program.parse("""(lambda (lambda
|
226 |
+
# (map (lambda (+ (index $0 $2) (index $0 $1))) (range (length $0)) )))""")
|
227 |
+
# # .replace("unfold", "#(lambda (lambda (lambda (lambda (fix1 $0 (lambda (lambda (#(lambda (lambda (lambda (if $0 empty (cons $1 $2))))) ($1 ($3 $0)) ($4 $0) ($5 $0)))))))))").\
|
228 |
+
# # replace("length", "#(lambda (fix1 $0 (lambda (lambda (if (empty? $0) 0 (+ ($1 (cdr $0)) 1))))))").\
|
229 |
+
# # replace("forloop", "(#(lambda (lambda (lambda (lambda (fix1 $0 (lambda (lambda (#(lambda (lambda (lambda (if $0 empty (cons $1 $2))))) ($1 ($3 $0)) ($4 $0) ($5 $0))))))))) (lambda (#(eq? 0) $0)) $0 (lambda (#(lambda (- $0 1)) $0)))").\
|
230 |
+
# # replace("inc","#(lambda (+ $0 1))").\
|
231 |
+
# # replace("drop","#(lambda (lambda (fix2 $0 $1 (lambda (lambda (lambda (if
|
232 |
+
# # (#(eq? 0) $1) $0 (cdr ($2 (- $1 1) $0)))))))))"))
|
233 |
+
# print(p)
|
234 |
+
# print(g.logLikelihood(t, p))
|
235 |
+
# assert False
|
236 |
+
|
237 |
+
# print("??")
|
238 |
+
# p = Program.parse(
|
239 |
+
# "#(lambda (#(lambda (lambda (lambda (fix1 $0 (lambda (lambda (if (empty? $0) $3 ($4 (car $0) ($1 (cdr $0)))))))))) (lambda $1) 1))")
|
240 |
+
# for j in range(10):
|
241 |
+
# l = list(range(j))
|
242 |
+
# print(l, p.evaluate([])(lambda x: x * 2)(l))
|
243 |
+
# print()
|
244 |
+
# print()
|
245 |
+
|
246 |
+
# print("multiply")
|
247 |
+
# p = Program.parse(
|
248 |
+
# "(lambda (lambda (lambda (if (eq? $0 0) 0 (+ $1 ($2 $1 (- $0 1)))))))")
|
249 |
+
# print(g.logLikelihood(arrow(arrow(tint, tint, tint), tint, tint, tint), p))
|
250 |
+
# print()
|
251 |
+
|
252 |
+
# print("take until 0")
|
253 |
+
# p = Program.parse("(lambda (lambda (if (eq? $1 0) empty (cons $1 $0))))")
|
254 |
+
# print(g.logLikelihood(arrow(tint, tlist(tint), tlist(tint)), p))
|
255 |
+
# print()
|
256 |
+
|
257 |
+
# print("countdown primitive")
|
258 |
+
# p = Program.parse(
|
259 |
+
# "(lambda (lambda (if (eq? $0 0) empty (cons (+ $0 1) ($1 (- $0 1))))))")
|
260 |
+
# print(
|
261 |
+
# g.logLikelihood(
|
262 |
+
# arrow(
|
263 |
+
# arrow(
|
264 |
+
# tint, tlist(tint)), arrow(
|
265 |
+
# tint, tlist(tint))), p))
|
266 |
+
# print(_fix(9)(p.evaluate([])))
|
267 |
+
# print("countdown w/ better primitives")
|
268 |
+
# p = Program.parse(
|
269 |
+
# "(lambda (lambda (if (eq0 $0) empty (cons (+1 $0) ($1 (-1 $0))))))")
|
270 |
+
# print(
|
271 |
+
# g.logLikelihood(
|
272 |
+
# arrow(
|
273 |
+
# arrow(
|
274 |
+
# tint, tlist(tint)), arrow(
|
275 |
+
# tint, tlist(tint))), p))
|
276 |
+
|
277 |
+
# print()
|
278 |
+
|
279 |
+
# print("prepend zeros")
|
280 |
+
# p = Program.parse(
|
281 |
+
# "(lambda (lambda (lambda (if (eq? $1 0) $0 (cons 0 ($2 (- $1 1) $0))))))")
|
282 |
+
# print(
|
283 |
+
# g.logLikelihood(
|
284 |
+
# arrow(
|
285 |
+
# arrow(
|
286 |
+
# tint,
|
287 |
+
# tlist(tint),
|
288 |
+
# tlist(tint)),
|
289 |
+
# tint,
|
290 |
+
# tlist(tint),
|
291 |
+
# tlist(tint)),
|
292 |
+
# p))
|
293 |
+
# print()
|
294 |
+
# assert False
|
295 |
+
|
296 |
+
# p = Program.parse(
|
297 |
+
# "(lambda (fix1 $0 (lambda (lambda (if (empty? $0) 0 (+ 1 ($1 (cdr $0))))))))")
|
298 |
+
# print(p.evaluate([])(list(range(17))))
|
299 |
+
# print(g.logLikelihood(arrow(tlist(tbool), tint), p))
|
300 |
+
|
301 |
+
# p = Program.parse(
|
302 |
+
# "(lambda (lambda (if (empty? $0) 0 (+ 1 ($1 (cdr $0))))))")
|
303 |
+
# print(
|
304 |
+
# g.logLikelihood(
|
305 |
+
# arrow(
|
306 |
+
# arrow(
|
307 |
+
# tlist(tbool), tint), arrow(
|
308 |
+
# tlist(tbool), tint)), p))
|
309 |
+
|
310 |
+
# p = Program.parse(
|
311 |
+
# "(lambda (fix1 $0 (lambda (lambda (if (empty? $0) 0 (+ (car $0) ($1 (cdr $0))))))))")
|
312 |
+
|
313 |
+
# print(p.evaluate([])(list(range(4))))
|
314 |
+
# print(g.logLikelihood(arrow(tlist(tint), tint), p))
|
315 |
+
|
316 |
+
# p = Program.parse(
|
317 |
+
# "(lambda (lambda (if (empty? $0) 0 (+ (car $0) ($1 (cdr $0))))))")
|
318 |
+
# print(p)
|
319 |
+
# print(
|
320 |
+
# g.logLikelihood(
|
321 |
+
# arrow(
|
322 |
+
# arrow(
|
323 |
+
# tlist(tint),
|
324 |
+
# tint),
|
325 |
+
# tlist(tint),
|
326 |
+
# tint),
|
327 |
+
# p))
|
328 |
+
|
329 |
+
# print("take")
|
330 |
+
# p = Program.parse(
|
331 |
+
# "(lambda (lambda (lambda (if (eq? $1 0) empty (cons (car $0) ($2 (- $1 1) (cdr $0)))))))")
|
332 |
+
# print(p)
|
333 |
+
# print(
|
334 |
+
# g.logLikelihood(
|
335 |
+
# arrow(
|
336 |
+
# arrow(
|
337 |
+
# tint,
|
338 |
+
# tlist(tint),
|
339 |
+
# tlist(tint)),
|
340 |
+
# tint,
|
341 |
+
# tlist(tint),
|
342 |
+
# tlist(tint)),
|
343 |
+
# p))
|
344 |
+
# assert False
|
345 |
+
|
346 |
+
# print(p.evaluate([])(list(range(4))))
|
347 |
+
# print(g.logLikelihood(arrow(tlist(tint), tlist(tint)), p))
|
348 |
+
|
349 |
+
# p = Program.parse(
|
350 |
+
# """(lambda (fix (lambda (lambda (match $0 0 (lambda (lambda (+ $1 ($3 $0))))))) $0))""")
|
351 |
+
# print(p.evaluate([])(list(range(4))))
|
352 |
+
# print(g.logLikelihood(arrow(tlist(tint), tint), p))
|
dreamcoder/domains/misc/napsPrimitives.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#napsPrimitives.py
|
2 |
+
from dreamcoder.program import Primitive, prettyProgram
|
3 |
+
from dreamcoder.grammar import Grammar
|
4 |
+
from dreamcoder.type import tlist, tint, arrow, baseType #, t0, t1, t2
|
5 |
+
|
6 |
+
#from functools import reduce
|
7 |
+
|
8 |
+
|
9 |
+
#types
|
10 |
+
PROGRAM = baseType("PROGRAM")
|
11 |
+
|
12 |
+
RECORD = baseType("RECORD")
|
13 |
+
FUNC = baseType("FUNC")
|
14 |
+
|
15 |
+
VAR = baseType("VAR")
|
16 |
+
STMT = baseType("STMT")
|
17 |
+
EXPR = baseType("EXPR")
|
18 |
+
ASSIGN = baseType("ASSIGN")
|
19 |
+
LHS = baseType("LHS")
|
20 |
+
IF = baseType("IF")
|
21 |
+
FOREACH = baseType("FOREACH")
|
22 |
+
WHILE = baseType("WHILE")
|
23 |
+
BREAK = baseType("BREAK")
|
24 |
+
CONTINUE = baseType("CONTINUE")
|
25 |
+
RETURN = baseType("RETURN")
|
26 |
+
NOOP = baseType("NOOP")
|
27 |
+
FIELD = baseType("FIELD")
|
28 |
+
CONSTANT = baseType("CONSTANT")
|
29 |
+
INVOKE = baseType("INVOKE")
|
30 |
+
TERNARY = baseType("TERNARY")
|
31 |
+
CAST = baseType("CAST")
|
32 |
+
TYPE = baseType("TYPE")
|
33 |
+
|
34 |
+
#other types
|
35 |
+
function_name = baseType("function_name")
|
36 |
+
field_name = baseType("field_name")
|
37 |
+
name = baseType("name") # for records and functions
|
38 |
+
value = baseType("value")
|
39 |
+
|
40 |
+
# definitions:
|
41 |
+
|
42 |
+
def _program(records): return lambda funcs: {'types': records, 'funcs': funcs}
|
43 |
+
# record
|
44 |
+
def _func(string): return lambda tp: lambda name: lambda vars1: lambda vars2: lambda stmts: [string, tp, name, vars1, vars2, stmts]
|
45 |
+
def _var(tp): return lambda name: ['var', tp, name]
|
46 |
+
# stmt
|
47 |
+
# expr
|
48 |
+
def _assign(tp): return lambda lhs: lambda expr: ['assign', tp, lhs, expr]
|
49 |
+
# lhs
|
50 |
+
def _if(tp): return lambda expr: lambda stmts1: lambda stmts2: ['if', tp, expr, stmts1, stmts2] # TODO
|
51 |
+
def _foreach(tp): return lambda var: lambda expr: lambda stmts: ['foreach', tp, expr, stmts] # TODO
|
52 |
+
def _while(tp): return lambda expr: lambda stmts1: lambda stmts1: ['while', tp, expr, stmts1, stmts2] # or: ['while', tp, expr, [stmts1], [stmts2]] #TODO
|
53 |
+
# break
|
54 |
+
# continue
|
55 |
+
def _return(tp): return lambda expr: ['return', tp, expr]
|
56 |
+
# noop
|
57 |
+
def _field(tp): return lambda expr: lambda field_name: ['field', tp, expr, field_name]
|
58 |
+
def _constant(tp): return lambda value: ['val', tp, value] #TODO deal with value
|
59 |
+
def _invoke(tp): return lambda function_name: lambda exprs: ['invoke', tp, function_name, exprs] # TODO, deal with fn_name and lists
|
60 |
+
def _ternary(tp): return lambda expr1: lambda expr2: lambda expr3: ['?:', tp, expr1, expr2, expr3]
|
61 |
+
def _cast(tp): return lambda expr: ['cast', tp, expr]
|
62 |
+
|
63 |
+
# types:
|
64 |
+
|
65 |
+
# TODO: deal with lists - x
|
66 |
+
# TODO: deal with names
|
67 |
+
# TODO: deal with values - x
|
68 |
+
|
69 |
+
# TODO: deal with the program/record __main__ and __globals__ stuff
|
70 |
+
|
71 |
+
|
72 |
+
|
73 |
+
def napsPrimitives():
|
74 |
+
return [
|
75 |
+
Primitive("program", arrow(tlist(RECORD), tlist(FUNC), PROGRAM), _program), # TODO
|
76 |
+
# RECORD
|
77 |
+
Primitive("func", arrow(TYPE, name, tlist(VAR), tlist(VAR), tlist(VAR), tlist(STMT)), _func('func')), # TODO
|
78 |
+
Primitive("ctor", arrow(TYPE, name, tlist(VAR), tlist(VAR), tlist(VAR), tlist(STMT)), _func('ctor')),
|
79 |
+
Primitive("var", arrow(TYPE, name, VAR), _var)
|
80 |
+
] + [
|
81 |
+
# STMT ::= EXPR | IF | FOREACH | WHILE | BREAK | CONTINUE | RETURN | NOOP
|
82 |
+
Primitive("stmt_expr", arrow(EXPR, STMT), lambda x: x),
|
83 |
+
Primitive("stmt_if", arrow(IF, STMT), lambda x: x),
|
84 |
+
Primitive("stmt_foreach", arrow(FOREACH, STMT), lambda x: x),
|
85 |
+
Primitive("stmt_while", arrow(WHILE, STMT), lambda x: x),
|
86 |
+
Primitive("stmt_break", arrow(BREAK, STMT), lambda x: x),
|
87 |
+
Primitive("stmt_continue", arrow(CONTINUE, STMT), lambda x: x),
|
88 |
+
Primitive("stmt_return", arrow(RETURN, STMT), lambda x: x),
|
89 |
+
Primitive("stmt_noop", arrow(NOOP, STMT), lambda x: x)
|
90 |
+
] + [
|
91 |
+
# EXPR ::= ASSIGN | VAR | FIELD | CONSTANT | INVOKE | TERNARY | CAST
|
92 |
+
Primitive("expr_assign", arrow(ASSIGN, EXPR), lambda x: x),
|
93 |
+
Primitive("expr_var", arrow(VAR, EXPR), lambda x: x),
|
94 |
+
Primitive("expr_field", arrow(FIELD, EXPR), lambda x: x),
|
95 |
+
Primitive("expr_constant", arrow(CONSTANT, EXPR), lambda x: x),
|
96 |
+
Primitive("expr_invoke", arrow(INVOKE, EXPR), lambda x: x),
|
97 |
+
Primitive("expr_ternary", arrow(TERNARY, EXPR), lambda x: x),
|
98 |
+
Primitive("expr_cast", arrow(CAST, EXPR), lambda x: x)
|
99 |
+
] + [
|
100 |
+
Primitive("assign", arrow(TYPE, LHS, EXPR, ASSIGN), _assign)
|
101 |
+
] + [
|
102 |
+
# LHS ::= VAR | FIELD | INVOKE
|
103 |
+
Primitive("lhs_var", arrow(VAR, LHS), lambda x: x),
|
104 |
+
Primitive("lhs_field", arrow(FIELD, LHS), lambda x: x),
|
105 |
+
Primitive("lhs_invoke", arrow(INVOKE, LHS), lambda x: x)
|
106 |
+
] + [
|
107 |
+
Primitive("if", arrow(TYPE, EXPR, tlist(STMT), tlist(STMT), IF), _if),
|
108 |
+
Primitive("foreach", arrow(TYPE, VAR, EXPR, tlist(STMT), FOREACH), _foreach),
|
109 |
+
Primitive("while", arrow(TYPE, EXPR, tlist(STMT), tlist(STMT), WHILE), _while),
|
110 |
+
Primitive("break", arrow(TYPE, BREAK), lambda tp: ['break', tp]),
|
111 |
+
Primitive("continue", arrow(TYPE, CONTINUE), lambda tp: ['continue', tp]),
|
112 |
+
Primitive("return", arrow(TYPE, EXPR, RETURN), _return),
|
113 |
+
Primitive("noop", NOOP, ['noop']),
|
114 |
+
Primitive("field", arrow(TYPE, EXPR, field_name, FIELD), _field), # TODO
|
115 |
+
Primitive("constant", arrow(TYPE, value, CONSTANT), _constant),
|
116 |
+
Primitive("invoke", arrow(TYPE, function_name, tlist(EXPR), INVOKE), _invoke), # TODO
|
117 |
+
Primitive("ternary", arrow(TYPE, EXPR, EXPR, EXPR, TERNARY), _ternary),
|
118 |
+
Primitive("cast", arrow(TYPE, EXPR, CAST), _cast)
|
119 |
+
] + [
|
120 |
+
# below are TYPE:
|
121 |
+
Primitive("bool", TYPE, 'bool'),
|
122 |
+
Primitive("char", TYPE, 'char'),
|
123 |
+
Primitive("char*", TYPE, 'char*'),
|
124 |
+
Primitive("int", TYPE, 'int'),
|
125 |
+
Primitive("real", TYPE, 'real'),
|
126 |
+
Primitive("array", arrow(TYPE, TYPE), lambda tp: tp + '*'),
|
127 |
+
Primitive("set", arrow(TYPE, TYPE), lambda tp: tp + '%'),
|
128 |
+
Primitive("map", arrow(TYPE, TYPE, TYPE), lambda tp1: lambda tp2: '<'+tp1+'|'+tp2+'>'),
|
129 |
+
Primitive("record_name", TYPE, 'record_name#') # TODO
|
130 |
+
] + [
|
131 |
+
#stuff about lists:
|
132 |
+
# STMTs, EXPRs, VARs, maybe Funcs and records
|
133 |
+
Primitive('list_init_stmt', arrow(STMT, tlist(STMT)), lambda stmt: [stmt]),
|
134 |
+
Primitive('list_add_stmt', arrow(STMT, tlist(STMT), tlist(STMT)), lambda stmt: lambda stmts: stmts + [stmt]),
|
135 |
+
Primitive('list_init_expr', arrow(EXPR, tlist(EXPR)), lambda expr: [expr]),
|
136 |
+
Primitive('list_add_expr', arrow(EXPR, tlist(EXPR), tlist(EXPR)), lambda expr: lambda exprs: exprs + [expr]),
|
137 |
+
Primitive('list_init_var', arrow(VAR, tlist(VAR)), lambda var: [var]),
|
138 |
+
Primitive('list_add_var', arrow(VAR, tlist(VAR), tlist(VAR)), lambda var: lambda _vars: _vars + [var])
|
139 |
+
] + [
|
140 |
+
# value
|
141 |
+
Primitive('0', value, 0),
|
142 |
+
Primitive("1", value, "1"),
|
143 |
+
Primitive("-1", value, "-1")
|
144 |
+
# ...
|
145 |
+
] + [
|
146 |
+
# function_name:
|
147 |
+
Primitive('+', function_name, '+'),
|
148 |
+
Primitive('&&', function_name, "&&"),
|
149 |
+
Primitive("!", function_name, "!"),
|
150 |
+
Primitive("!=", function_name, "!="),
|
151 |
+
Primitive("string_find", function_name,"string_find")
|
152 |
+
# ...
|
153 |
+
] + [
|
154 |
+
# field_name:
|
155 |
+
Primitive('', field_name, '')
|
156 |
+
# ...
|
157 |
+
] + [
|
158 |
+
#
|
159 |
+
Primitive(f'var{str(i)}', name, f'var{str(i)}') for i in range(12)
|
160 |
+
]
|
161 |
+
|
162 |
+
|
163 |
+
#for first pass, can just hard code vars and maps n stuff
|
164 |
+
|
165 |
+
def ec_prog_to_uast(prog): # TODO
|
166 |
+
# ideally, just evaluate and then parse
|
167 |
+
uast = prog.evaluate([])
|
168 |
+
return uast
|
169 |
+
|
170 |
+
def deepcoderProductions():
|
171 |
+
return [(0.0, prim) for prim in deepcoderPrimitives()]
|
172 |
+
|
173 |
+
# def flatten_program(p):
|
174 |
+
# string = p.show(False)
|
175 |
+
# num_inputs = string.count('lambda')
|
176 |
+
# string = string.replace('lambda', '')
|
177 |
+
# string = string.replace('(', '')
|
178 |
+
# string = string.replace(')', '')
|
179 |
+
# #remove '_fn' (optional)
|
180 |
+
# for i in range(num_inputs):
|
181 |
+
# string = string.replace('$' + str(num_inputs-i-1),'input_' + str(i))
|
182 |
+
# string = string.split(' ')
|
183 |
+
# string = list(filter(lambda x: x is not '', string))
|
184 |
+
# return string
|
185 |
+
|
186 |
+
if __name__ == "__main__":
|
187 |
+
#g = Grammar.uniform(deepcoderPrimitives())
|
188 |
+
g = Grammar.fromProductions(deepcoderProductions(), logVariable=.9)
|
189 |
+
request = arrow(tlist(tint), tint, tint)
|
190 |
+
p = g.sample(request)
|
191 |
+
print("request:", request)
|
192 |
+
print("program:")
|
193 |
+
print(prettyProgram(p))
|
194 |
+
print("flattened_program:")
|
195 |
+
flat = flatten_program(p)
|
196 |
+
print(flat)
|
197 |
+
|
198 |
+
|
dreamcoder/domains/regex/__init__.py
ADDED
File without changes
|
dreamcoder/domains/regex/groundtruthRegexes.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
#dict of gt regexes
|
3 |
+
|
4 |
+
"""
|
5 |
+
pre.create(".+"),
|
6 |
+
pre.create("\d+"),
|
7 |
+
pre.create("\w+"),
|
8 |
+
pre.create("\s+"),
|
9 |
+
pre.create("\\u+"),
|
10 |
+
pre.create("\l+")
|
11 |
+
"""
|
12 |
+
gt_dict = {
|
13 |
+
776: "JPC\\u\\u\\d+\\.png",
|
14 |
+
922: "WHS\\d_\\d+",
|
15 |
+
354: "\\u+",
|
16 |
+
523: "(\\u)+|\\.",
|
17 |
+
184: "\\.\\d+",
|
18 |
+
501: "u\\d\\d",
|
19 |
+
760: "\\u\\u",
|
20 |
+
49: "(\\u)+\\u\\d?",
|
21 |
+
732: "\\uR5\\d\\d",
|
22 |
+
450: "-\\d(\\.(\\d)+)?",
|
23 |
+
350: "\\u\\u",
|
24 |
+
467: "hu\\d(\\d|\\u)+",
|
25 |
+
622: "A(\\d|\\u)**",
|
26 |
+
476: "\\u+",
|
27 |
+
554: "\\u\\u",
|
28 |
+
940: "\\u\\u?",
|
29 |
+
496: "\\u\\u",
|
30 |
+
369: "\\u\\u\\u",
|
31 |
+
596: "\\u+",
|
32 |
+
720: "\\(\\d\\d\\d\\) \\d\\d\\d-\\d\\d\\d\\d",
|
33 |
+
53: "rec-\\d\\d\\d?-(org)|(dup-0)",
|
34 |
+
150: "N\\d\\d",
|
35 |
+
741: "#\\d\\d\\d",
|
36 |
+
18: "A|C-\\d+-\\d+",
|
37 |
+
589: "A(\\u|\\d)++",
|
38 |
+
666: "\\(\\d\\d\\d\\) \\d\\d\\d-\\d\\d\\d\\d",
|
39 |
+
581: "us13\\u\\d\\d",
|
40 |
+
299: "E07000\\d\\d\\d",
|
41 |
+
638: "\\l+\\d+\\l+\\d+",
|
42 |
+
364: "\\u\\u",
|
43 |
+
334: "-00:\\d\\d:\\d\\d.\\d",
|
44 |
+
38: "SRX89\\d+",
|
45 |
+
247: "'\\d\\d:\\d\\d:00'",
|
46 |
+
506: "(S|H)\\d+",
|
47 |
+
891: "(r|v)\\d?",
|
48 |
+
911: "KW-\\d+",
|
49 |
+
792: "\\d*\\u*",
|
50 |
+
508: "N000\\d+",
|
51 |
+
842: "-?\\d?\\d\\.\\d\\d%",
|
52 |
+
200: "\\u\\u",
|
53 |
+
694: "\\(\\d+\\)",
|
54 |
+
210: "(\\d(\\.\\d)?)|(--)",
|
55 |
+
298: "DS_25(\\u|\\d)+",
|
56 |
+
668: "\\u+",
|
57 |
+
939: "ms0\\d+",
|
58 |
+
944: "\\u+\\d?",
|
59 |
+
731: "ManH.0\\d\\d",
|
60 |
+
229: "\\u+(-\\u+)?",
|
61 |
+
28: "Y201\\d/\\d\\d\\d\\d",
|
62 |
+
374: "q000\\d(_000\\d)?",
|
63 |
+
819: "\\d*\\l*\\d*",
|
64 |
+
516: "-122.3\\d+",
|
65 |
+
417: "\\u\\uT\\uB",
|
66 |
+
660: "ENGL?\\d\\d\\d",
|
67 |
+
585: "M?\\u+",
|
68 |
+
325: "BUS M \\d\\d\\d.*",
|
69 |
+
823: "\\u\\u\\u",
|
70 |
+
515: "L|\\u - (\\?\\?)|(\\d?\\d\\.\\d lbs\\.)",
|
71 |
+
864: "\\u+",
|
72 |
+
359: "MAM\\.OSBS\\.201\\d\\.\\d\\d",
|
73 |
+
594: "(\\u|\\d)+( (\\u|\\d)+)*",
|
74 |
+
788: "-\\d(,\\d+)?",
|
75 |
+
188: "cat\\. \\d\\d",
|
76 |
+
355: ".+",
|
77 |
+
799: "\\u\\d\\d",
|
78 |
+
902: "\\u\\d\\d",
|
79 |
+
920: "A\\.\\d\\d",
|
80 |
+
330: "Resp\\d\\d",
|
81 |
+
396: "\\u+(( |/)\\u+)?",
|
82 |
+
393: "US $ \\d\\.\\d\\d",
|
83 |
+
680: "Z:-?0\\.\\d\\d",
|
84 |
+
744: "t1_cv(\\l|\\d)+",
|
85 |
+
461: "(\\u|\\l)+\\d+",
|
86 |
+
631: "$\\d+\\.\\d+",
|
87 |
+
195: "(OLE)?\\d+",
|
88 |
+
693: "\\u",
|
89 |
+
577: "EFO_000\\d+",
|
90 |
+
392: "$\\d+(,\\d\\d\\d)*\\.00",
|
91 |
+
688: "\\u+( \\u+)*",
|
92 |
+
816: "\\u\\u\\u",
|
93 |
+
489: "UK\\u\\d",
|
94 |
+
251: "\\l\\l\\l",
|
95 |
+
653: "C\\d+",
|
96 |
+
769: "(\\u|\\l|\\d|-)+\\d+",
|
97 |
+
991: "Q\\d-201\\d",
|
98 |
+
342: "\\u\\u\\d\\d\\d\\d",
|
99 |
+
308: "\\u\\u\\u\\u",
|
100 |
+
136: "IMPC_\\u\\u\\u_\\d\\d\\d_\\d\\d\\d",
|
101 |
+
327: "#\\d+((/|-)\\d+)*",
|
102 |
+
981: "\\u\\u\\u",
|
103 |
+
892: "(.|\\l)*",
|
104 |
+
375: "P\\u\\.\\d\\d\\d\\d\\.\\d\\d\\d",
|
105 |
+
499: "A000\\d+",
|
106 |
+
474: "\\u+",
|
107 |
+
50: "V06\\d+",
|
108 |
+
381: "F?\\d+",
|
109 |
+
883: "-79.\\d+",
|
110 |
+
173: "(\\u|\\l)+\\d+",
|
111 |
+
147: "\\u\\u\\u-\\u\\u\\u",
|
112 |
+
419: "\\u\\u",
|
113 |
+
961: "-?\\d\\.\\d*",
|
114 |
+
148: "Q\\d\\d",
|
115 |
+
975: "(\\d|\\u)+",
|
116 |
+
79: "\\d+(,\\d\\d\\d)+",
|
117 |
+
775: "\\u\\l\\l \\d+ \\d\\d\\d\\d",
|
118 |
+
774: "FOS\\d\\d+",
|
119 |
+
561: ".+",
|
120 |
+
509: "S000\\d+",
|
121 |
+
494: "S1900\\d+",
|
122 |
+
119: "$\\d\\d(,\\d\\d\\d)+",
|
123 |
+
29: "(\\u|\\l|\\d)+",
|
124 |
+
121: "(\\d|\\u|\\.|/|\\(|\\))+",
|
125 |
+
61: "R \\d\\d\\d.\\d\\d",
|
126 |
+
871: "-0.7\\d+",
|
127 |
+
639: "\\u+?\\d+",
|
128 |
+
729: "COMISARIA \\d\\d",
|
129 |
+
193: "\\u\\d\\d",
|
130 |
+
752: "(.*|\\u\\.?)+",
|
131 |
+
17: "$\\d.\\d\\d",
|
132 |
+
914: "R\\d\\d\\d\\d",
|
133 |
+
510: "P\\d000\\d\\d\\d\\d",
|
134 |
+
443: "(W|L) \\d-\\d+",
|
135 |
+
20: "MDEL\\d\\d?\\.\\d\\l",
|
136 |
+
64: "c04p0100(\\l|\\d)",
|
137 |
+
301: "(\\u|\\d)+(-(\\u|\\d)+)*",
|
138 |
+
664: "N\\d",
|
139 |
+
493: "[0\\.0\\d+]",
|
140 |
+
765: "-?\\d\\.\\d+( \\(0\\.\\d+\\))?"
|
141 |
+
|
142 |
+
|
143 |
+
}
|
144 |
+
badRegexTasks = {
|
145 |
+
"Data column no. 922",
|
146 |
+
"Data column no. 184",
|
147 |
+
"Data column no. 467",
|
148 |
+
"Data column no. 476",
|
149 |
+
"Data column no. 150",
|
150 |
+
"Data column no. 299",
|
151 |
+
"Data column no. 334",
|
152 |
+
"Data column no. 493",
|
153 |
+
"Data column no. 891",
|
154 |
+
"Data column no. 792",
|
155 |
+
"Data column no. 765",
|
156 |
+
"Data column no. 944",
|
157 |
+
"Data column no. 374",
|
158 |
+
"Data column no. 660",
|
159 |
+
"Data column no. 188",
|
160 |
+
"Data column no. 920",
|
161 |
+
"Data column no. 330",
|
162 |
+
"Data column no. 396",
|
163 |
+
"Data column no. 680",
|
164 |
+
"Data column no. 769",
|
165 |
+
"Data column no. 308",
|
166 |
+
"Data column no. 375",
|
167 |
+
"Data column no. 474",
|
168 |
+
"Data column no. 79",
|
169 |
+
"Data column no. 871",
|
170 |
+
"Data column no. 729",
|
171 |
+
"Data column no. 664",
|
172 |
+
}
|
dreamcoder/domains/regex/main.py
ADDED
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# analog of list.py for regex tasks. Responsible for actually running the task.
|
2 |
+
|
3 |
+
from dreamcoder.domains.regex.makeRegexTasks import makeOldTasks, makeLongTasks, makeShortTasks, makeWordTasks, makeNumberTasks, makeHandPickedTasks, makeNewTasks, makeNewNumberTasks
|
4 |
+
from dreamcoder.domains.regex.regexPrimitives import basePrimitives, altPrimitives, easyWordsPrimitives, alt2Primitives, concatPrimitives, reducedConcatPrimitives, strConstConcatPrimitives, PRC
|
5 |
+
from dreamcoder.dreamcoder import explorationCompression, Task
|
6 |
+
from dreamcoder.grammar import Grammar
|
7 |
+
from dreamcoder.likelihoodModel import add_cutoff_values, add_string_constants
|
8 |
+
from dreamcoder.program import Abstraction, Application
|
9 |
+
from dreamcoder.type import tpregex
|
10 |
+
from dreamcoder.utilities import eprint, flatten, testTrainSplit, POSITIVEINFINITY
|
11 |
+
|
12 |
+
import random
|
13 |
+
import math
|
14 |
+
import pregex as pre
|
15 |
+
import os
|
16 |
+
|
17 |
+
try:
|
18 |
+
from dreamcoder.recognition import RecurrentFeatureExtractor, JSONFeatureExtractor
|
19 |
+
class LearnedFeatureExtractor(RecurrentFeatureExtractor):
|
20 |
+
H = 64
|
21 |
+
special = 'regex'
|
22 |
+
|
23 |
+
def tokenize(self, examples):
|
24 |
+
def sanitize(l): return [z if z in self.lexicon else "?"
|
25 |
+
for z_ in l
|
26 |
+
for z in (z_ if isinstance(z_, list) else [z_])]
|
27 |
+
|
28 |
+
tokenized = []
|
29 |
+
for xs, y in examples:
|
30 |
+
if isinstance(y, list):
|
31 |
+
y = ["LIST_START"] + y + ["LIST_END"]
|
32 |
+
else:
|
33 |
+
y = [y]
|
34 |
+
y = sanitize(y)
|
35 |
+
if len(y) > self.maximumLength:
|
36 |
+
return None
|
37 |
+
|
38 |
+
serializedInputs = []
|
39 |
+
for xi, x in enumerate(xs):
|
40 |
+
if isinstance(x, list):
|
41 |
+
x = ["LIST_START"] + x + ["LIST_END"]
|
42 |
+
else:
|
43 |
+
x = [x]
|
44 |
+
x = sanitize(x)
|
45 |
+
if len(x) > self.maximumLength:
|
46 |
+
return None
|
47 |
+
serializedInputs.append(x)
|
48 |
+
|
49 |
+
tokenized.append((tuple(serializedInputs), y))
|
50 |
+
|
51 |
+
return tokenized
|
52 |
+
|
53 |
+
def __init__(self, tasks, testingTasks=[], cuda=False):
|
54 |
+
self.lexicon = set(flatten((t.examples for t in tasks + testingTasks), abort=lambda x: isinstance(
|
55 |
+
x, str))).union({"LIST_START", "LIST_END", "?"})
|
56 |
+
|
57 |
+
self.num_examples_list = [len(t.examples) for t in tasks]
|
58 |
+
|
59 |
+
# Calculate the maximum length
|
60 |
+
self.maximumLength = POSITIVEINFINITY
|
61 |
+
self.maximumLength = max(len(l)
|
62 |
+
for t in tasks + testingTasks
|
63 |
+
for xs, y in self.tokenize(t.examples)
|
64 |
+
for l in [y] + [x for x in xs])
|
65 |
+
|
66 |
+
super(
|
67 |
+
LearnedFeatureExtractor,
|
68 |
+
self).__init__(
|
69 |
+
lexicon=list(
|
70 |
+
self.lexicon),
|
71 |
+
tasks=tasks,
|
72 |
+
cuda=cuda,
|
73 |
+
H=self.H,
|
74 |
+
bidirectional=True)
|
75 |
+
self.parallelTaskOfProgram = False
|
76 |
+
|
77 |
+
|
78 |
+
def taskOfProgram(self, p, t):
|
79 |
+
#raise NotImplementedError
|
80 |
+
num_examples = random.choice(self.num_examples_list)
|
81 |
+
|
82 |
+
p = p.visit(ConstantInstantiateVisitor.SINGLE)
|
83 |
+
|
84 |
+
preg = p.evaluate([])(pre.String(""))
|
85 |
+
t = Task("Helm", t, [((), list(preg.sample())) for _ in range(num_examples) ])
|
86 |
+
return t
|
87 |
+
except: pass
|
88 |
+
#in init: loop over tasks, save lengths,
|
89 |
+
|
90 |
+
|
91 |
+
class ConstantInstantiateVisitor(object):
|
92 |
+
def __init__(self):
|
93 |
+
self.regexes = [
|
94 |
+
pre.create(".+"),
|
95 |
+
pre.create("\d+"),
|
96 |
+
pre.create("\w+"),
|
97 |
+
pre.create("\s+"),
|
98 |
+
pre.create("\\u+"),
|
99 |
+
pre.create("\l+")]
|
100 |
+
|
101 |
+
def primitive(self, e):
|
102 |
+
if e.name == "r_const":
|
103 |
+
#return Primitive("STRING", e.tp, random.choice(self.words))
|
104 |
+
s = random.choice(self.regexes).sample() #random string const
|
105 |
+
s = pre.String(s)
|
106 |
+
e.value = PRC(s,arity=0)
|
107 |
+
return e
|
108 |
+
|
109 |
+
def invented(self, e): return e.body.visit(self)
|
110 |
+
|
111 |
+
def index(self, e): return e
|
112 |
+
|
113 |
+
def application(self, e):
|
114 |
+
return Application(e.f.visit(self), e.x.visit(self))
|
115 |
+
|
116 |
+
def abstraction(self, e):
|
117 |
+
return Abstraction(e.body.visit(self))
|
118 |
+
#TODO fix
|
119 |
+
|
120 |
+
|
121 |
+
|
122 |
+
|
123 |
+
class MyJSONFeatureExtractor(JSONFeatureExtractor):
|
124 |
+
N_EXAMPLES = 5
|
125 |
+
|
126 |
+
def _featuresOfProgram(self, program, tp):
|
127 |
+
try:
|
128 |
+
preg = program.evaluate([])
|
129 |
+
# if 'left_paren' in program.show(False):
|
130 |
+
#eprint("string_pregex:", string_pregex)
|
131 |
+
#eprint("string_pregex:", string_pregex)
|
132 |
+
|
133 |
+
except IndexError:
|
134 |
+
# free variable
|
135 |
+
return None
|
136 |
+
except Exception as e:
|
137 |
+
eprint("Exception during evaluation:", e)
|
138 |
+
if "Attempt to evaluate fragment variable" in e:
|
139 |
+
eprint("program (bc fragment error)", program)
|
140 |
+
return None
|
141 |
+
|
142 |
+
examples = []
|
143 |
+
|
144 |
+
for _ in range(self.N_EXAMPLES * 5): # oh this is arbitrary ig
|
145 |
+
|
146 |
+
try:
|
147 |
+
y = preg.sample() # TODO
|
148 |
+
|
149 |
+
#this line should keep inputs short, so that helmholtzbatch can be large
|
150 |
+
#allows it to try other samples
|
151 |
+
#(Could also return None off the bat... idk which is better)
|
152 |
+
#if len(y) > 20:
|
153 |
+
# continue
|
154 |
+
#eprint(tp, program, x, y)
|
155 |
+
examples.append(y)
|
156 |
+
except BaseException:
|
157 |
+
continues
|
158 |
+
if len(examples) >= self.N_EXAMPLES:
|
159 |
+
break
|
160 |
+
else:
|
161 |
+
return None
|
162 |
+
return examples # changed to list_features(examples) from examples
|
163 |
+
|
164 |
+
|
165 |
+
def regex_options(parser):
|
166 |
+
parser.add_argument("--maxTasks", type=int,
|
167 |
+
default=500,
|
168 |
+
help="truncate tasks to fit within this boundary")
|
169 |
+
parser.add_argument(
|
170 |
+
"--maxExamples",
|
171 |
+
type=int,
|
172 |
+
default=10,
|
173 |
+
help="truncate number of examples per task to fit within this boundary")
|
174 |
+
parser.add_argument("--tasks",
|
175 |
+
default="long",
|
176 |
+
help="which tasks to use",
|
177 |
+
choices=["old", "short", "long", "words", "number", "handpicked", "new", "newNumber"])
|
178 |
+
parser.add_argument("--primitives",
|
179 |
+
default="concat",
|
180 |
+
help="Which primitive set to use",
|
181 |
+
choices=["base", "alt1", "easyWords", "alt2", "concat", "reduced", "strConst"])
|
182 |
+
parser.add_argument("--extractor", type=str,
|
183 |
+
choices=["hand", "deep", "learned", "json"],
|
184 |
+
default="learned") # if i switch to json it breaks
|
185 |
+
parser.add_argument("--split", metavar="TRAIN_RATIO",
|
186 |
+
type=float,
|
187 |
+
default=0.8,
|
188 |
+
help="split test/train")
|
189 |
+
parser.add_argument("-H", "--hidden", type=int,
|
190 |
+
default=256,
|
191 |
+
help="number of hidden units")
|
192 |
+
parser.add_argument("--likelihoodModel",
|
193 |
+
default="probabilistic",
|
194 |
+
help="likelihood Model",
|
195 |
+
choices=["probabilistic", "all-or-nothing"])
|
196 |
+
parser.add_argument("--topk_use_map",
|
197 |
+
dest="topk_use_only_likelihood",
|
198 |
+
action="store_false")
|
199 |
+
parser.add_argument("--debug",
|
200 |
+
dest="debug",
|
201 |
+
action="store_true")
|
202 |
+
parser.add_argument("--ll_cutoff",
|
203 |
+
dest="use_ll_cutoff",
|
204 |
+
nargs='*',
|
205 |
+
default=False,
|
206 |
+
help="use ll cutoff for training tasks (for probabilistic likelihood model only). default is False,")
|
207 |
+
parser.add_argument("--use_str_const",
|
208 |
+
action="store_true",
|
209 |
+
help="use string constants")
|
210 |
+
|
211 |
+
"""parser.add_argument("--stardecay",
|
212 |
+
type=float,
|
213 |
+
dest="stardecay",
|
214 |
+
default=0.5,
|
215 |
+
help="p value for kleenestar and plus")"""
|
216 |
+
|
217 |
+
# Lucas recommends putting a struct with the definitions of the primitives here.
|
218 |
+
# TODO:
|
219 |
+
# Build likelihood funciton
|
220 |
+
# modify NN
|
221 |
+
# make primitives
|
222 |
+
# make tasks
|
223 |
+
|
224 |
+
|
225 |
+
def main(args):
|
226 |
+
"""
|
227 |
+
Takes the return value of the `commandlineArguments()` function as input and
|
228 |
+
trains/tests the model on regular expressions.
|
229 |
+
"""
|
230 |
+
#for dreaming
|
231 |
+
|
232 |
+
#parse use_ll_cutoff
|
233 |
+
use_ll_cutoff = args.pop('use_ll_cutoff')
|
234 |
+
if not use_ll_cutoff is False:
|
235 |
+
|
236 |
+
#if use_ll_cutoff is a list of strings, then train_ll_cutoff and train_ll_cutoff
|
237 |
+
#will be tuples of that string followed by the actual model
|
238 |
+
|
239 |
+
if len(use_ll_cutoff) == 1:
|
240 |
+
train_ll_cutoff = use_ll_cutoff[0] # make_cutoff_model(use_ll_cutoff[0], tasks))
|
241 |
+
test_ll_cutoff = use_ll_cutoff[0] # make_cutoff_model(use_ll_cutoff[0], tasks))
|
242 |
+
else:
|
243 |
+
assert len(use_ll_cutoff) == 2
|
244 |
+
train_ll_cutoff = use_ll_cutoff[0] #make_cutoff_model(use_ll_cutoff[0], tasks))
|
245 |
+
test_ll_cutoff = use_ll_cutoff[1] #make_cutoff_model(use_ll_cutoff[1], tasks))
|
246 |
+
else:
|
247 |
+
train_ll_cutoff = None
|
248 |
+
test_ll_cutoff = None
|
249 |
+
|
250 |
+
|
251 |
+
regexTasks = {"old": makeOldTasks,
|
252 |
+
"short": makeShortTasks,
|
253 |
+
"long": makeLongTasks,
|
254 |
+
"words": makeWordTasks,
|
255 |
+
"number": makeNumberTasks,
|
256 |
+
"handpicked": makeHandPickedTasks,
|
257 |
+
"new": makeNewTasks,
|
258 |
+
"newNumber": makeNewNumberTasks
|
259 |
+
}[args.pop("tasks")]
|
260 |
+
|
261 |
+
tasks = regexTasks() # TODO
|
262 |
+
eprint("Generated", len(tasks), "tasks")
|
263 |
+
|
264 |
+
maxTasks = args.pop("maxTasks")
|
265 |
+
if len(tasks) > maxTasks:
|
266 |
+
eprint("Unwilling to handle {} tasks, truncating..".format(len(tasks)))
|
267 |
+
seed = 42 # previously this was hardcoded and never changed
|
268 |
+
random.seed(seed)
|
269 |
+
random.shuffle(tasks)
|
270 |
+
del tasks[maxTasks:]
|
271 |
+
|
272 |
+
maxExamples = args.pop("maxExamples")
|
273 |
+
|
274 |
+
|
275 |
+
split = args.pop("split")
|
276 |
+
test, train = testTrainSplit(tasks, split)
|
277 |
+
eprint("Split tasks into %d/%d test/train" % (len(test), len(train)))
|
278 |
+
|
279 |
+
|
280 |
+
test = add_cutoff_values(test, test_ll_cutoff)
|
281 |
+
train = add_cutoff_values(train, train_ll_cutoff)
|
282 |
+
eprint("added cutoff values to tasks, train: ", train_ll_cutoff, ", test:", test_ll_cutoff )
|
283 |
+
|
284 |
+
|
285 |
+
if args.pop("use_str_const"):
|
286 |
+
assert args["primitives"] == "strConst" or args["primitives"] == "reduced"
|
287 |
+
ConstantInstantiateVisitor.SINGLE = \
|
288 |
+
ConstantInstantiateVisitor()
|
289 |
+
test = add_string_constants(test)
|
290 |
+
train = add_string_constants(train)
|
291 |
+
eprint("added string constants to test and train")
|
292 |
+
|
293 |
+
for task in test + train:
|
294 |
+
if len(task.examples) > maxExamples:
|
295 |
+
task.examples = task.examples[:maxExamples]
|
296 |
+
|
297 |
+
task.specialTask = ("regex", {"cutoff": task.ll_cutoff, "str_const": task.str_const})
|
298 |
+
task.examples = [(xs, [y for y in ys ])
|
299 |
+
for xs,ys in task.examples ]
|
300 |
+
task.maxParameters = 1
|
301 |
+
|
302 |
+
# from list stuff
|
303 |
+
primtype = args.pop("primitives")
|
304 |
+
prims = {"base": basePrimitives,
|
305 |
+
"alt1": altPrimitives,
|
306 |
+
"alt2": alt2Primitives,
|
307 |
+
"easyWords": easyWordsPrimitives,
|
308 |
+
"concat": concatPrimitives,
|
309 |
+
"reduced": reducedConcatPrimitives,
|
310 |
+
"strConst": strConstConcatPrimitives
|
311 |
+
}[primtype]
|
312 |
+
|
313 |
+
extractor = {
|
314 |
+
"learned": LearnedFeatureExtractor,
|
315 |
+
"json": MyJSONFeatureExtractor
|
316 |
+
}[args.pop("extractor")]
|
317 |
+
|
318 |
+
extractor.H = args.pop("hidden")
|
319 |
+
|
320 |
+
#stardecay = args.stardecay
|
321 |
+
#stardecay = args.pop('stardecay')
|
322 |
+
#decaystr = 'd' + str(stardecay)
|
323 |
+
import datetime
|
324 |
+
|
325 |
+
timestamp = datetime.datetime.now().isoformat()
|
326 |
+
outputDirectory = "experimentOutputs/regex/%s"%timestamp
|
327 |
+
os.system("mkdir -p %s"%outputDirectory)
|
328 |
+
|
329 |
+
args.update({
|
330 |
+
"featureExtractor": extractor,
|
331 |
+
"outputPrefix": "%s/regex"%(outputDirectory),
|
332 |
+
"evaluationTimeout": 0.005,
|
333 |
+
"topk_use_only_likelihood": True,
|
334 |
+
"maximumFrontier": 10,
|
335 |
+
"compressor": args.get("compressor","ocaml")
|
336 |
+
})
|
337 |
+
####
|
338 |
+
|
339 |
+
|
340 |
+
# use the
|
341 |
+
#prim_list = prims(stardecay)
|
342 |
+
prim_list = prims()
|
343 |
+
specials = ["r_kleene", "r_plus", "r_maybe", "r_alt", "r_concat"]
|
344 |
+
n_base_prim = len(prim_list) - len(specials)
|
345 |
+
|
346 |
+
productions = [
|
347 |
+
(math.log(0.5 / float(n_base_prim)),
|
348 |
+
prim) if prim.name not in specials else (
|
349 |
+
math.log(0.10),
|
350 |
+
prim) for prim in prim_list]
|
351 |
+
|
352 |
+
|
353 |
+
baseGrammar = Grammar.fromProductions(productions, continuationType=tpregex)
|
354 |
+
#baseGrammar = Grammar.uniform(prims())
|
355 |
+
|
356 |
+
#for i in range(100):
|
357 |
+
# eprint(baseGrammar.sample(tpregex))
|
358 |
+
|
359 |
+
#eprint(baseGrammar)
|
360 |
+
#explore
|
361 |
+
test_stuff = args.pop("debug")
|
362 |
+
if test_stuff:
|
363 |
+
eprint(baseGrammar)
|
364 |
+
eprint("sampled programs from prior:")
|
365 |
+
for i in range(100): #100
|
366 |
+
eprint(baseGrammar.sample(test[0].request,maximumDepth=1000))
|
367 |
+
eprint("""half the probability mass is on higher-order primitives.
|
368 |
+
Therefore half of enumerated programs should have more than one node.
|
369 |
+
However, we do not observe this.
|
370 |
+
Instead we see a very small fraction of programs have more than one node.
|
371 |
+
So something seems to be wrong with grammar.sample.
|
372 |
+
|
373 |
+
Furthermore: observe the large print statement above.
|
374 |
+
This prints the candidates for sampleDistribution in grammar.sample.
|
375 |
+
the first element of each tuple is the probability passed into sampleDistribution.
|
376 |
+
Half of the probability mass should be on the functions, but instead they are equally
|
377 |
+
weighted with the constants. If you look at the grammar above, this is an error!!!!
|
378 |
+
""")
|
379 |
+
assert False
|
380 |
+
|
381 |
+
del args["likelihoodModel"]
|
382 |
+
explorationCompression(baseGrammar, train,
|
383 |
+
testingTasks = test,
|
384 |
+
**args)
|
dreamcoder/domains/regex/makeRegexTasks.py
ADDED
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dill
|
2 |
+
import os
|
3 |
+
import json
|
4 |
+
from string import printable
|
5 |
+
|
6 |
+
import sys
|
7 |
+
try:
|
8 |
+
from pregex import pregex
|
9 |
+
except:
|
10 |
+
print("Failure to load pregex. This is only acceptable if using pypy",file=sys.stderr)
|
11 |
+
|
12 |
+
from dreamcoder.task import Task
|
13 |
+
from dreamcoder.type import tpregex, arrow
|
14 |
+
from dreamcoder.utilities import get_data_dir
|
15 |
+
|
16 |
+
|
17 |
+
def makeOldTasks():
|
18 |
+
# a series of tasks
|
19 |
+
|
20 |
+
taskfile = os.path.join(get_data_dir(), 'data_filtered.json')
|
21 |
+
#task_list = pickle.load(open(taskfile, 'rb'))
|
22 |
+
|
23 |
+
with open(taskfile) as f:
|
24 |
+
file_contents = f.read()
|
25 |
+
task_list = json.loads(file_contents)
|
26 |
+
|
27 |
+
|
28 |
+
# if I were to just dump all of them:
|
29 |
+
regextasks = [
|
30 |
+
Task("Luke data column no." + str(i),
|
31 |
+
arrow(tpregex, tpregex),
|
32 |
+
[((), example) for example in task_list[i]]
|
33 |
+
) for i in range(len(task_list))]
|
34 |
+
|
35 |
+
|
36 |
+
""" regextasks = [
|
37 |
+
Task("length bool", arrow(none,tstr),
|
38 |
+
[((l,), len(l))
|
39 |
+
for _ in range(10)
|
40 |
+
for l in [[flip() for _ in range(randint(0,10)) ]] ]),
|
41 |
+
Task("length int", arrow(none,tstr),
|
42 |
+
[((l,), len(l))
|
43 |
+
for _ in range(10)
|
44 |
+
for l in [randomList()] ]),
|
45 |
+
]
|
46 |
+
"""
|
47 |
+
return regextasks # some list of tasks
|
48 |
+
|
49 |
+
|
50 |
+
|
51 |
+
|
52 |
+
def makeShortTasks():
|
53 |
+
|
54 |
+
#load new data:
|
55 |
+
|
56 |
+
taskfile = os.path.join(get_data_dir(), "regex_data_csv_900.p")
|
57 |
+
|
58 |
+
with open(taskfile, 'rb') as handle:
|
59 |
+
data = dill.load(handle)
|
60 |
+
|
61 |
+
tasklist = data[0][:100] #a list of indices
|
62 |
+
|
63 |
+
regextasks = [
|
64 |
+
Task("Data column no. " + str(i),
|
65 |
+
arrow(tpregex, tpregex),
|
66 |
+
[((), example) for example in task]
|
67 |
+
) for i, task in enumerate(tasklist)]
|
68 |
+
|
69 |
+
|
70 |
+
|
71 |
+
return regextasks
|
72 |
+
|
73 |
+
def makeLongTasks():
|
74 |
+
|
75 |
+
#load new data:
|
76 |
+
|
77 |
+
taskfile = os.path.join(get_data_dir(), "regex_data_csv_900.p")
|
78 |
+
|
79 |
+
with open(taskfile, 'rb') as handle:
|
80 |
+
data = dill.load(handle)
|
81 |
+
|
82 |
+
tasklist = data[0] #a list of indices
|
83 |
+
|
84 |
+
regextasks = [
|
85 |
+
Task("Data column no. " + str(i),
|
86 |
+
arrow(tpregex, tpregex),
|
87 |
+
[((), example) for example in task]
|
88 |
+
) for i, task in enumerate(tasklist)]
|
89 |
+
|
90 |
+
|
91 |
+
|
92 |
+
return regextasks
|
93 |
+
|
94 |
+
def makeWordTasks():
|
95 |
+
|
96 |
+
#load new data:
|
97 |
+
|
98 |
+
taskfile = os.path.join(get_data_dir(), "regex_data_csv_900.p")
|
99 |
+
|
100 |
+
with open(taskfile, 'rb') as handle:
|
101 |
+
data = dill.load(handle)
|
102 |
+
|
103 |
+
tasklist = data[0] #a list of indices
|
104 |
+
|
105 |
+
|
106 |
+
|
107 |
+
|
108 |
+
all_upper = [0, 2, 8, 9, 10, 11, 12, 17, 18, 19, 20, 22]
|
109 |
+
all_lower = [1]
|
110 |
+
|
111 |
+
# match_col(data[0],'\\u(\l+)')
|
112 |
+
one_capital_lower_plus = [144, 200, 241, 242, 247, 296, 390, 392, 444, 445, 481, 483, 485, 489, 493, 542, 549, 550, 581]
|
113 |
+
|
114 |
+
#match_col(data[0],'(\l ?)+')
|
115 |
+
lower_with_maybe_spaces = [1, 42, 47, 99, 100, 102, 201, 246, 248, 293, 294, 345, 437, 545, 590]
|
116 |
+
|
117 |
+
#match_col(data[0],'(\\u\l+ ?)+')
|
118 |
+
capital_then_lower_maybe_spaces = [144, 200, 241, 242, 247, 296, 390, 392, 395, 438, 444, 445, 481, 483, 484, 485, 487, 489, 493, 494, 542, 546, 549, 550, 578, 581, 582, 588, 591, 624, 629]
|
119 |
+
|
120 |
+
#match_col(data[0],'(\\u+ ?)+')
|
121 |
+
all_caps_spaces = [0, 2, 8, 9, 10, 11, 12, 17, 18, 19, 20, 22, 25, 26, 35, 36, 43, 45, 46, 49, 50, 52, 56, 59, 87, 89, 95, 101, 140, 147, 148, 149, 199, 332, 336, 397, 491, 492, 495, 580, 610]
|
122 |
+
|
123 |
+
#one_capital_and_lower = [566, 550, 549, 542, 505, 493, 494, 489, 488, 485, 483, 481, 445, 444, 438, 296, 241, 242, 200, ]
|
124 |
+
#all_lower_with_a_space = [545]
|
125 |
+
#all_lower_maybe_space = [534]
|
126 |
+
#one_capital_lower_maybe_spaces = [259, 262, 263, 264]
|
127 |
+
|
128 |
+
|
129 |
+
#full_list = test_list + train_list
|
130 |
+
train_list = []
|
131 |
+
full_list = all_upper + all_lower + one_capital_lower_plus + lower_with_maybe_spaces + capital_then_lower_maybe_spaces + all_caps_spaces
|
132 |
+
|
133 |
+
regextasks = [
|
134 |
+
Task("Data column no. " + str(i),
|
135 |
+
arrow(tpregex, tpregex),
|
136 |
+
[((), example) for example in task]
|
137 |
+
) for i, task in enumerate(tasklist) if i in full_list ]
|
138 |
+
|
139 |
+
for i in train_list:
|
140 |
+
regextasks[i].mustTrain = True
|
141 |
+
|
142 |
+
|
143 |
+
return regextasks
|
144 |
+
|
145 |
+
def makeNumberTasks():
|
146 |
+
|
147 |
+
#load new data:
|
148 |
+
|
149 |
+
taskfile = os.path.join(get_data_dir(), "regex_data_csv_900.p")
|
150 |
+
|
151 |
+
with open(taskfile, 'rb') as handle:
|
152 |
+
data = dill.load(handle)
|
153 |
+
|
154 |
+
tasklist = data[0] #a list of indices
|
155 |
+
|
156 |
+
|
157 |
+
|
158 |
+
#match_col(data[0],'\d*\.\d*')
|
159 |
+
raw_decimals = [121, 122, 163, 164, 165, 170, 172, 173, 175, 178, 218, 228, 230, 231, 252, 253,
|
160 |
+
254, 258, 259, 305, 320, 330, 334, 340, 348, 350, 351, 352, 353, 355, 357, 358, 361, 363, 364,
|
161 |
+
371, 380, 382, 409, 410, 411, 447, 448, 449, 450, 458, 469, 471, 533, 562, 564]
|
162 |
+
|
163 |
+
|
164 |
+
decimals_pos_neg_dollar = [3, 4, 5, 6, 7, 13, 16, 24, 27, 28, 29, 30, 31, 32, 33, 34, 37, 38, 39, 40,
|
165 |
+
53, 54, 55, 57, 58, 60, 61, 63, 64, 65, 66, 68, 69, 70, 71, 73, 74, 77, 78, 80, 81, 103, 104, 105,
|
166 |
+
106, 107, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 121, 122, 123, 124, 125, 126, 128,
|
167 |
+
129, 131, 132, 134, 135, 139, 146, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165,
|
168 |
+
166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 180, 181, 182, 183, 184, 185, 186,
|
169 |
+
193, 194, 195, 204, 205, 207, 209, 210, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 223, 224,
|
170 |
+
225, 226, 227, 228, 229, 230, 231, 232, 249, 250, 251, 252, 253, 254, 255, 256, 258, 259, 260, 261,
|
171 |
+
263, 266, 267, 270, 271, 272, 277, 299, 301, 302, 305, 306, 307, 309, 312, 313, 315, 319, 320, 324,
|
172 |
+
326, 327, 330, 334, 340, 348, 350, 351, 352, 353, 354, 355, 356, 357, 358, 361, 362, 363, 364, 368,
|
173 |
+
371, 373, 377, 380, 382, 400, 401, 402, 403, 405, 406, 409, 410, 411, 413, 435, 439, 446, 447, 448,
|
174 |
+
449, 450, 451, 452, 453, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 469, 470, 471, 477,
|
175 |
+
498, 500, 502, 503, 507, 512, 518, 519, 520, 532, 533, 553, 554, 555, 556, 557, 558, 559, 560, 561,
|
176 |
+
562, 564, 565, 572, 577]
|
177 |
+
|
178 |
+
#match_col(data[0],'(\d*,?\d*)+')
|
179 |
+
commas = []
|
180 |
+
#match_col(data[0],'(\d*,?\d*)+')
|
181 |
+
commas_and_all = []
|
182 |
+
|
183 |
+
#full_list = test_list + train_list
|
184 |
+
train_list = []
|
185 |
+
full_list = decimals_pos_neg_dollar
|
186 |
+
|
187 |
+
regextasks = [
|
188 |
+
Task("Data column no. " + str(i),
|
189 |
+
arrow(tpregex, tpregex),
|
190 |
+
[((), example) for example in task]
|
191 |
+
) for i, task in enumerate(tasklist) if i in full_list ]
|
192 |
+
|
193 |
+
for i in train_list:
|
194 |
+
regextasks[i].mustTrain = True
|
195 |
+
|
196 |
+
|
197 |
+
return regextasks
|
198 |
+
|
199 |
+
|
200 |
+
def makeHandPickedTasks():
|
201 |
+
|
202 |
+
#load new data:
|
203 |
+
|
204 |
+
taskfile = os.path.join(get_data_dir(), "regex_data_csv_900.p")
|
205 |
+
|
206 |
+
with open(taskfile, 'rb') as handle:
|
207 |
+
data = dill.load(handle)
|
208 |
+
|
209 |
+
tasklist = data[0] #a list of indices
|
210 |
+
|
211 |
+
|
212 |
+
full_list = list(range(199)) + \
|
213 |
+
[209,218,222,223,224,225,226] + \
|
214 |
+
list(range(222,233)) + \
|
215 |
+
[235,237,238,239,243,244,245,252,253,254,255,257,258,259,260,261,264,265,269,272,274] + \
|
216 |
+
list(range(275,291)) + \
|
217 |
+
[295,297,300,303,304,305,306,310,311,312,314,315,316,320,321,323,327,329,330,333,334,335,337,338,339,340,341,342,343,344] + \
|
218 |
+
list(range(348,359)) + \
|
219 |
+
[361,369,373,379,380,382,387,403,405,407,408] + \
|
220 |
+
list(range(409,417)) + \
|
221 |
+
list(range(418,437)) + \
|
222 |
+
list(range(440,444)) + \
|
223 |
+
list(range(446,452)) + \
|
224 |
+
list(range(456,460)) + \
|
225 |
+
list(range(466,472)) + \
|
226 |
+
[503,504]
|
227 |
+
|
228 |
+
|
229 |
+
regextasks = [
|
230 |
+
Task("Data column no. " + str(i),
|
231 |
+
arrow(tpregex, tpregex),
|
232 |
+
[((), example) for example in task]
|
233 |
+
) for i, task in enumerate(tasklist) if i in full_list ]
|
234 |
+
|
235 |
+
#for i in train_list:
|
236 |
+
# regextasks[i].mustTrain = True
|
237 |
+
|
238 |
+
|
239 |
+
return regextasks
|
240 |
+
|
241 |
+
def makeNewTasks(include_only=None):
|
242 |
+
|
243 |
+
#load new data:
|
244 |
+
|
245 |
+
taskfile = os.path.join(get_data_dir(), "csv_filtered_all_background_novel.p")
|
246 |
+
|
247 |
+
with open(taskfile, 'rb') as handle:
|
248 |
+
data = dill.load(handle)
|
249 |
+
|
250 |
+
tasklist = data['background'] #a list of indices
|
251 |
+
|
252 |
+
if include_only:
|
253 |
+
regextasks = [
|
254 |
+
Task("Data column no. " + str(i),
|
255 |
+
arrow(tpregex, tpregex),
|
256 |
+
[((), example) for example in task['train']]
|
257 |
+
) for i, task in enumerate(tasklist) if i in include_only]
|
258 |
+
else:
|
259 |
+
regextasks = [
|
260 |
+
Task("Data column no. " + str(i),
|
261 |
+
arrow(tpregex, tpregex),
|
262 |
+
[((), example) for example in task['train']]
|
263 |
+
) for i, task in enumerate(tasklist)]
|
264 |
+
|
265 |
+
#for i in train_list:
|
266 |
+
# regextasks[i].mustTrain = True
|
267 |
+
|
268 |
+
return regextasks
|
269 |
+
REGEXTASKS = None
|
270 |
+
def regexHeldOutExamples(task, include_only=None):
|
271 |
+
|
272 |
+
#load new data:
|
273 |
+
global REGEXTASKS
|
274 |
+
if REGEXTASKS is None:
|
275 |
+
taskfile = os.path.join(get_data_dir(), "csv_filtered_all_background_novel.p")
|
276 |
+
|
277 |
+
with open(taskfile, 'rb') as handle:
|
278 |
+
data = dill.load(handle)
|
279 |
+
|
280 |
+
tasklist = data['background'] #a list of indices
|
281 |
+
|
282 |
+
if include_only:
|
283 |
+
regextasks = [
|
284 |
+
Task("Data column no. " + str(i),
|
285 |
+
arrow(tpregex, tpregex),
|
286 |
+
[((), example) for example in _task['test']]
|
287 |
+
) for i, _task in enumerate(tasklist) if i in include_only]
|
288 |
+
else:
|
289 |
+
regextasks = [
|
290 |
+
Task("Data column no. " + str(i),
|
291 |
+
arrow(tpregex, tpregex),
|
292 |
+
[((), example) for example in _task['test']]
|
293 |
+
) for i, _task in enumerate(tasklist)]
|
294 |
+
|
295 |
+
#for i in train_list:
|
296 |
+
# regextasks[i].mustTrain = True
|
297 |
+
|
298 |
+
REGEXTASKS = {t.name: t.examples for t in regextasks}
|
299 |
+
fullTask = REGEXTASKS[task.name]
|
300 |
+
return fullTask
|
301 |
+
|
302 |
+
|
303 |
+
|
304 |
+
def makeNewNumberTasks():
|
305 |
+
|
306 |
+
tasks = makeNewTasks()
|
307 |
+
numberTasks = [t for t in tasks if not any(p in ex for p in printable[10:62] for _, ex in t.examples)]
|
308 |
+
return numberTasks
|
309 |
+
|
310 |
+
|
311 |
+
# a helper function which takes a list of lists and sees which match a specific regex.
|
312 |
+
def match_col(dataset, rstring):
|
313 |
+
r = pregex.create(rstring)
|
314 |
+
matches = []
|
315 |
+
for i, col in enumerate(dataset):
|
316 |
+
score = sum([r.match(example) for example in col])
|
317 |
+
if score != float('-inf'):
|
318 |
+
matches.append(i)
|
319 |
+
return matches
|
320 |
+
|
321 |
+
if __name__ == "__main__":
|
322 |
+
import argparse
|
323 |
+
parser = argparse.ArgumentParser()
|
324 |
+
parser.add_argument("--include_only",
|
325 |
+
default=None,
|
326 |
+
nargs="+",
|
327 |
+
type=int)
|
328 |
+
args = parser.parse_args()
|
329 |
+
|
330 |
+
|
331 |
+
def show_tasks(dataset):
|
332 |
+
task_list = []
|
333 |
+
for task in dataset:
|
334 |
+
print(task.name)
|
335 |
+
print([example[1] for example in task.examples[:20]])
|
336 |
+
task_list.append([example[1] for example in task.examples])
|
337 |
+
return task_list
|
338 |
+
|
339 |
+
task = {"number": makeNumberTasks,
|
340 |
+
"words": makeWordTasks,
|
341 |
+
"all": makeLongTasks,
|
342 |
+
"new": makeNewTasks}['new']
|
343 |
+
|
344 |
+
|
345 |
+
x = show_tasks(task(args.include_only))
|
346 |
+
|
347 |
+
|
dreamcoder/domains/regex/regexPrimitives.py
ADDED
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
from dreamcoder.program import Primitive
|
3 |
+
from dreamcoder.grammar import Grammar
|
4 |
+
from dreamcoder.type import arrow, tpregex
|
5 |
+
from string import printable
|
6 |
+
|
7 |
+
try:
|
8 |
+
from pregex import pregex
|
9 |
+
except:
|
10 |
+
print("Failure to load pregex. This is only acceptable if using pypy",file=sys.stderr)
|
11 |
+
|
12 |
+
|
13 |
+
# evaluation to regular regex form. then I can unflatten using Luke's stuff.
|
14 |
+
|
15 |
+
|
16 |
+
def _kleene(x): return pregex.KleeneStar(x, p=0.25)
|
17 |
+
|
18 |
+
|
19 |
+
def _plus(x): return pregex.Plus(x, p=0.25)
|
20 |
+
|
21 |
+
|
22 |
+
def _maybe(x): return pregex.Maybe(x)
|
23 |
+
|
24 |
+
|
25 |
+
# maybe should be reversed#"(" + x + "|" + y + ")"
|
26 |
+
def _alt(x): return lambda y: pregex.Alt([x, y])
|
27 |
+
|
28 |
+
|
29 |
+
def _concat(x): return lambda y: pregex.Concat([x, y]) # "(" + x + y + ")"
|
30 |
+
|
31 |
+
|
32 |
+
#For sketch:
|
33 |
+
def _kleene_5(x): return pregex.KleeneStar(x)
|
34 |
+
|
35 |
+
def _plus_5(x): return pregex.Plus(x)
|
36 |
+
|
37 |
+
|
38 |
+
disallowed = [
|
39 |
+
("#", "hash"),
|
40 |
+
("!", "bang"),
|
41 |
+
("\"", "double_quote"),
|
42 |
+
("$", "dollar"),
|
43 |
+
("%", "percent"),
|
44 |
+
("&", "ampersand"),
|
45 |
+
("'", "single_quote"),
|
46 |
+
(")", "left_paren"),
|
47 |
+
("(", "right_paren"),
|
48 |
+
("*", "astrisk"),
|
49 |
+
("+", "plus"),
|
50 |
+
(",", "comma"),
|
51 |
+
("-", "dash"),
|
52 |
+
(".", "period"),
|
53 |
+
("/", "slash"),
|
54 |
+
(":", "colon"),
|
55 |
+
(";", "semicolon"),
|
56 |
+
("<", "less_than"),
|
57 |
+
("=", "equal"),
|
58 |
+
(">", "greater_than"),
|
59 |
+
("?", "question_mark"),
|
60 |
+
("@", "at"),
|
61 |
+
("[", "left_bracket"),
|
62 |
+
("\\", "backslash"),
|
63 |
+
("]", "right_bracket"),
|
64 |
+
("^", "carrot"),
|
65 |
+
("_", "underscore"),
|
66 |
+
("`", "backtick"),
|
67 |
+
("|", "bar"),
|
68 |
+
("}", "right_brace"),
|
69 |
+
("{", "left_brace"),
|
70 |
+
("~", "tilde"),
|
71 |
+
(" ", "space"),
|
72 |
+
("\t", "tab")
|
73 |
+
]
|
74 |
+
|
75 |
+
disallowed_list = [char for char, _ in disallowed]
|
76 |
+
|
77 |
+
class PRC(): #PregexContinuation
|
78 |
+
def __init__(self, f, arity=0, args=[]):
|
79 |
+
self.f = f
|
80 |
+
self.arity = arity
|
81 |
+
self.args = args
|
82 |
+
|
83 |
+
def __call__(self, pre):
|
84 |
+
|
85 |
+
if self.arity == len(self.args):
|
86 |
+
if self.arity == 0: return pregex.Concat([self.f, pre])
|
87 |
+
elif self.arity == 1: return pregex.Concat([self.f(*self.args), pre])
|
88 |
+
else: return pregex.Concat([self.f(self.args), pre]) #this line is bad, need brackets around input to f if f is Alt
|
89 |
+
else: return PRC(self.f, self.arity, args=self.args+[pre(pregex.String(""))])
|
90 |
+
|
91 |
+
|
92 |
+
def concatPrimitives():
|
93 |
+
return [Primitive("string_" + i, arrow(tpregex, tpregex), PRC(pregex.String(i))) for i in printable[:-4] if i not in disallowed_list
|
94 |
+
] + [
|
95 |
+
Primitive("string_" + name, arrow(tpregex, tpregex), PRC(pregex.String(char))) for char, name in disallowed
|
96 |
+
] + [
|
97 |
+
Primitive("r_dot", arrow(tpregex, tpregex), PRC(pregex.dot)),
|
98 |
+
Primitive("r_d", arrow(tpregex, tpregex), PRC(pregex.d)),
|
99 |
+
Primitive("r_s", arrow(tpregex, tpregex), PRC(pregex.s)),
|
100 |
+
Primitive("r_w", arrow(tpregex, tpregex), PRC(pregex.w)),
|
101 |
+
Primitive("r_l", arrow(tpregex, tpregex), PRC(pregex.l)),
|
102 |
+
Primitive("r_u", arrow(tpregex, tpregex), PRC(pregex.u)),
|
103 |
+
#todo
|
104 |
+
Primitive("r_kleene", arrow(arrow(tpregex, tpregex), arrow(tpregex,tpregex)), PRC(pregex.KleeneStar,1)),
|
105 |
+
Primitive("r_plus", arrow(arrow(tpregex, tpregex), arrow(tpregex,tpregex)), PRC(pregex.Plus,1)),
|
106 |
+
Primitive("r_maybe", arrow(arrow(tpregex, tpregex), arrow(tpregex,tpregex)), PRC(pregex.Maybe,1)),
|
107 |
+
Primitive("r_alt", arrow(arrow(tpregex, tpregex) , arrow(tpregex, tpregex), arrow(tpregex,tpregex)), PRC(pregex.Alt, 2)),
|
108 |
+
]
|
109 |
+
|
110 |
+
def strConstConcatPrimitives():
|
111 |
+
return [Primitive("string_" + i, arrow(tpregex, tpregex), PRC(pregex.String(i))) for i in printable[:-4] if i not in disallowed_list
|
112 |
+
] + [
|
113 |
+
Primitive("string_" + name, arrow(tpregex, tpregex), PRC(pregex.String(char))) for char, name in disallowed
|
114 |
+
] + [
|
115 |
+
Primitive("r_dot", arrow(tpregex, tpregex), PRC(pregex.dot)),
|
116 |
+
Primitive("r_d", arrow(tpregex, tpregex), PRC(pregex.d)),
|
117 |
+
Primitive("r_s", arrow(tpregex, tpregex), PRC(pregex.s)),
|
118 |
+
Primitive("r_w", arrow(tpregex, tpregex), PRC(pregex.w)),
|
119 |
+
Primitive("r_l", arrow(tpregex, tpregex), PRC(pregex.l)),
|
120 |
+
Primitive("r_u", arrow(tpregex, tpregex), PRC(pregex.u)),
|
121 |
+
#todo
|
122 |
+
Primitive("r_kleene", arrow(arrow(tpregex, tpregex), arrow(tpregex,tpregex)), PRC(pregex.KleeneStar,1)),
|
123 |
+
Primitive("r_plus", arrow(arrow(tpregex, tpregex), arrow(tpregex,tpregex)), PRC(pregex.Plus,1)),
|
124 |
+
Primitive("r_maybe", arrow(arrow(tpregex, tpregex), arrow(tpregex,tpregex)), PRC(pregex.Maybe,1)),
|
125 |
+
Primitive("r_alt", arrow(arrow(tpregex, tpregex) , arrow(tpregex, tpregex), arrow(tpregex,tpregex)), PRC(pregex.Alt, 2)),
|
126 |
+
] + [
|
127 |
+
Primitive("r_const", arrow(tpregex, tpregex), None)
|
128 |
+
]
|
129 |
+
|
130 |
+
|
131 |
+
def reducedConcatPrimitives():
|
132 |
+
#uses strConcat!!
|
133 |
+
#[Primitive("empty_string", arrow(tpregex, tpregex), PRC(pregex.String("")))
|
134 |
+
#] + [
|
135 |
+
return [Primitive("string_" + i, arrow(tpregex, tpregex), PRC(pregex.String(i))) for i in printable[:-4] if i not in disallowed_list
|
136 |
+
] + [
|
137 |
+
Primitive("string_" + name, arrow(tpregex, tpregex), PRC(pregex.String(char))) for char, name in disallowed
|
138 |
+
] + [
|
139 |
+
Primitive("r_dot", arrow(tpregex, tpregex), PRC(pregex.dot)),
|
140 |
+
Primitive("r_d", arrow(tpregex, tpregex), PRC(pregex.d)),
|
141 |
+
Primitive("r_s", arrow(tpregex, tpregex), PRC(pregex.s)),
|
142 |
+
#Primitive("r_w", arrow(tpregex, tpregex), PRC(pregex.w)),
|
143 |
+
Primitive("r_l", arrow(tpregex, tpregex), PRC(pregex.l)),
|
144 |
+
Primitive("r_u", arrow(tpregex, tpregex), PRC(pregex.u)),
|
145 |
+
#todo
|
146 |
+
Primitive("r_kleene", arrow(arrow(tpregex, tpregex), arrow(tpregex,tpregex)), PRC(pregex.KleeneStar,1)),
|
147 |
+
#Primitive("r_plus", arrow(arrow(tpregex, tpregex), arrow(tpregex,tpregex)), PRC(pregex.Plus,1)),
|
148 |
+
Primitive("r_maybe", arrow(arrow(tpregex, tpregex), arrow(tpregex,tpregex)), PRC(pregex.Maybe,1)),
|
149 |
+
Primitive("r_alt", arrow(arrow(tpregex, tpregex) , arrow(tpregex, tpregex), arrow(tpregex,tpregex)), PRC(pregex.Alt, 2)),
|
150 |
+
] + [
|
151 |
+
Primitive("r_const", arrow(tpregex, tpregex), None)
|
152 |
+
]
|
153 |
+
|
154 |
+
|
155 |
+
def sketchPrimitives():
|
156 |
+
return [Primitive("string_" + i, tpregex, pregex.String(i)) for i in printable[:-4] if i not in disallowed_list
|
157 |
+
] + [
|
158 |
+
Primitive("string_" + name, tpregex, pregex.String(char)) for char, name in disallowed
|
159 |
+
] + [
|
160 |
+
Primitive("r_dot", tpregex, pregex.dot),
|
161 |
+
Primitive("r_d", tpregex, pregex.d),
|
162 |
+
Primitive("r_s", tpregex, pregex.s),
|
163 |
+
Primitive("r_w", tpregex, pregex.w),
|
164 |
+
Primitive("r_l", tpregex, pregex.l),
|
165 |
+
Primitive("r_u", tpregex, pregex.u),
|
166 |
+
Primitive("r_kleene", arrow(tpregex, tpregex), _kleene_5),
|
167 |
+
Primitive("r_plus", arrow(tpregex, tpregex), _plus_5),
|
168 |
+
Primitive("r_maybe", arrow(tpregex, tpregex), _maybe),
|
169 |
+
Primitive("r_alt", arrow(tpregex, tpregex, tpregex), _alt),
|
170 |
+
Primitive("r_concat", arrow(tpregex, tpregex, tpregex), _concat),
|
171 |
+
]
|
172 |
+
|
173 |
+
def basePrimitives():
|
174 |
+
return [Primitive("string_" + i, tpregex, pregex.String(i)) for i in printable[:-4] if i not in disallowed_list
|
175 |
+
] + [
|
176 |
+
Primitive("string_" + name, tpregex, pregex.String(char)) for char, name in disallowed
|
177 |
+
] + [
|
178 |
+
Primitive("r_dot", tpregex, pregex.dot),
|
179 |
+
Primitive("r_d", tpregex, pregex.d),
|
180 |
+
Primitive("r_s", tpregex, pregex.s),
|
181 |
+
Primitive("r_w", tpregex, pregex.w),
|
182 |
+
Primitive("r_l", tpregex, pregex.l),
|
183 |
+
Primitive("r_u", tpregex, pregex.u),
|
184 |
+
Primitive("r_kleene", arrow(tpregex, tpregex), _kleene),
|
185 |
+
Primitive("r_plus", arrow(tpregex, tpregex), _plus),
|
186 |
+
Primitive("r_maybe", arrow(tpregex, tpregex), _maybe),
|
187 |
+
Primitive("r_alt", arrow(tpregex, tpregex, tpregex), _alt),
|
188 |
+
Primitive("r_concat", arrow(tpregex, tpregex, tpregex), _concat),
|
189 |
+
]
|
190 |
+
|
191 |
+
|
192 |
+
|
193 |
+
def altPrimitives():
|
194 |
+
return [
|
195 |
+
Primitive("empty_string", tpregex, pregex.String(""))
|
196 |
+
] + [
|
197 |
+
Primitive("string_" + i, tpregex, pregex.String(i)) for i in printable[:-4] if i not in disallowed_list
|
198 |
+
] + [
|
199 |
+
Primitive("string_" + name, tpregex, pregex.String(char)) for char, name in disallowed
|
200 |
+
] + [
|
201 |
+
Primitive("r_dot", tpregex, pregex.dot),
|
202 |
+
Primitive("r_d", tpregex, pregex.d),
|
203 |
+
Primitive("r_s", tpregex, pregex.s),
|
204 |
+
Primitive("r_w", tpregex, pregex.w),
|
205 |
+
Primitive("r_l", tpregex, pregex.l),
|
206 |
+
Primitive("r_u", tpregex, pregex.u),
|
207 |
+
Primitive("r_kleene", arrow(tpregex, tpregex), _kleene),
|
208 |
+
#Primitive("r_plus", arrow(tpregex, tpregex), _plus),
|
209 |
+
Primitive("r_maybe", arrow(tpregex, tpregex), _maybe),
|
210 |
+
Primitive("r_alt", arrow(tpregex, tpregex, tpregex), _alt),
|
211 |
+
Primitive("r_concat", arrow(tpregex, tpregex, tpregex), _concat),
|
212 |
+
]
|
213 |
+
|
214 |
+
def alt2Primitives():
|
215 |
+
return [
|
216 |
+
Primitive("empty_string", tpregex, pregex.String(""))
|
217 |
+
] + [
|
218 |
+
Primitive("string_" + i, tpregex, pregex.String(i)) for i in printable[:-4] if i not in disallowed_list
|
219 |
+
] + [
|
220 |
+
Primitive("string_" + name, tpregex, pregex.String(char)) for char, name in disallowed
|
221 |
+
] + [
|
222 |
+
Primitive("r_dot", tpregex, pregex.dot),
|
223 |
+
Primitive("r_d", tpregex, pregex.d),
|
224 |
+
Primitive("r_s", tpregex, pregex.s),
|
225 |
+
Primitive("r_w", tpregex, pregex.w),
|
226 |
+
Primitive("r_l", tpregex, pregex.l),
|
227 |
+
Primitive("r_u", tpregex, pregex.u),
|
228 |
+
Primitive("r_kleene", arrow(tpregex, tpregex), _kleene),
|
229 |
+
#Primitive("r_plus", arrow(tpregex, tpregex), _plus),
|
230 |
+
#Primitive("r_maybe", arrow(tpregex, tpregex), _maybe),
|
231 |
+
Primitive("r_alt", arrow(tpregex, tpregex, tpregex), _alt),
|
232 |
+
Primitive("r_concat", arrow(tpregex, tpregex, tpregex), _concat),
|
233 |
+
]
|
234 |
+
|
235 |
+
def easyWordsPrimitives():
|
236 |
+
return [
|
237 |
+
Primitive("string_" + i, tpregex, pregex.String(i)) for i in printable[10:62] if i not in disallowed_list
|
238 |
+
] + [
|
239 |
+
Primitive("r_d", tpregex, pregex.d),
|
240 |
+
Primitive("r_s", tpregex, pregex.s),
|
241 |
+
#Primitive("r_w", tpregex, pregex.w),
|
242 |
+
Primitive("r_l", tpregex, pregex.l),
|
243 |
+
Primitive("r_u", tpregex, pregex.u),
|
244 |
+
Primitive("r_kleene", arrow(tpregex, tpregex), _kleene),
|
245 |
+
Primitive("r_plus", arrow(tpregex, tpregex), _plus),
|
246 |
+
Primitive("r_maybe", arrow(tpregex, tpregex), _maybe),
|
247 |
+
Primitive("r_alt", arrow(tpregex, tpregex, tpregex), _alt),
|
248 |
+
Primitive("r_concat", arrow(tpregex, tpregex, tpregex), _concat),
|
249 |
+
]
|
250 |
+
|
251 |
+
|
252 |
+
#def _wrapper(x): return lambda y: y
|
253 |
+
|
254 |
+
#specials = [".","*","+","?","|"]
|
255 |
+
"""
|
256 |
+
>>> import pregex as pre
|
257 |
+
>>> abc = pre.CharacterClass("abc", [0.1, 0.1, 0.8], name="MyConcept")
|
258 |
+
>>> abc.sample()
|
259 |
+
'b'
|
260 |
+
>>> abc.sample()
|
261 |
+
'c'
|
262 |
+
>>> abc.sample()
|
263 |
+
'c'
|
264 |
+
>>> abc.match("c")
|
265 |
+
-0.2231435513142097
|
266 |
+
>>> abc.match("a")
|
267 |
+
-2.3025850929940455
|
268 |
+
>>> abc
|
269 |
+
MyConcept
|
270 |
+
>>> x = pre.KleeneStar(abc)
|
271 |
+
>>> x.match("aabbac")
|
272 |
+
-16.58809928020405
|
273 |
+
>>> x.sample()
|
274 |
+
''
|
275 |
+
>>> x.sample()
|
276 |
+
''
|
277 |
+
>>> x.sample()
|
278 |
+
'cbcacc'
|
279 |
+
>>> x
|
280 |
+
(KleeneStar 0.5 MyConcept)
|
281 |
+
>>> str(x)
|
282 |
+
'MyConcept*'
|
283 |
+
"""
|
284 |
+
|
285 |
+
|
286 |
+
def emp_dot(corpus): return pregex.CharacterClass(printable[:-4], emp_distro_from_corpus(corpus, printable[:-4]), name=".")
|
287 |
+
|
288 |
+
def emp_d(corpus): return pregex.CharacterClass(printable[:10], emp_distro_from_corpus(corpus, printable[:10]), name="\\d")
|
289 |
+
|
290 |
+
#emp_s = pre.CharacterClass(slist, [], name="emp\\s") #may want to forgo this one.
|
291 |
+
|
292 |
+
def emp_dot_no_letter(corpus): return pregex.CharacterClass(printable[:10]+printable[62:], emp_distro_from_corpus(corpus, printable[:10]+printable[62:]), name=".")
|
293 |
+
|
294 |
+
def emp_w(corpus): return pregex.CharacterClass(printable[:62], emp_distro_from_corpus(corpus, printable[:62]), name="\\w")
|
295 |
+
|
296 |
+
def emp_l(corpus): return pregex.CharacterClass(printable[10:36], emp_distro_from_corpus(corpus, printable[10:36]), name="\\l")
|
297 |
+
|
298 |
+
def emp_u(corpus): return pregex.CharacterClass(printable[36:62], emp_distro_from_corpus(corpus, printable[36:62]), name="\\u")
|
299 |
+
|
300 |
+
|
301 |
+
def emp_distro_from_corpus(corpus, char_list):
|
302 |
+
from collections import Counter
|
303 |
+
c = Counter(char for task in corpus for example in task.examples for string in example[1] for char in string)
|
304 |
+
n = sum(c[char] for char in char_list)
|
305 |
+
return [c[char]/n for char in char_list]
|
306 |
+
|
307 |
+
|
308 |
+
|
309 |
+
def matchEmpericalPrimitives(corpus):
|
310 |
+
return lambda: [
|
311 |
+
Primitive("empty_string", tpregex, pregex.String(""))
|
312 |
+
] + [
|
313 |
+
Primitive("string_" + i, tpregex, pregex.String(i)) for i in printable[:-4] if i not in disallowed_list
|
314 |
+
] + [
|
315 |
+
Primitive("string_" + name, tpregex, pregex.String(char)) for char, name in disallowed
|
316 |
+
] + [
|
317 |
+
Primitive("r_dot", tpregex, emp_dot(corpus) ),
|
318 |
+
Primitive("r_d", tpregex, emp_d(corpus) ),
|
319 |
+
Primitive("r_s", tpregex, pregex.s),
|
320 |
+
Primitive("r_w", tpregex, emp_w(corpus) ),
|
321 |
+
Primitive("r_l", tpregex, emp_l(corpus) ),
|
322 |
+
Primitive("r_u", tpregex, emp_u(corpus) ),
|
323 |
+
Primitive("r_kleene", arrow(tpregex, tpregex), _kleene),
|
324 |
+
#Primitive("r_plus", arrow(tpregex, tpregex), _plus),
|
325 |
+
Primitive("r_maybe", arrow(tpregex, tpregex), _maybe),
|
326 |
+
Primitive("r_alt", arrow(tpregex, tpregex, tpregex), _alt),
|
327 |
+
Primitive("r_concat", arrow(tpregex, tpregex, tpregex), _concat),
|
328 |
+
]
|
329 |
+
|
330 |
+
def matchEmpericalNoLetterPrimitives(corpus):
|
331 |
+
return lambda: [
|
332 |
+
Primitive("empty_string", tpregex, pregex.String(""))
|
333 |
+
] + [
|
334 |
+
Primitive("string_" + i, tpregex, pregex.String(i)) for i in printable[:-4] if i not in disallowed_list + list(printable[10:62])
|
335 |
+
] + [
|
336 |
+
Primitive("string_" + name, tpregex, pregex.String(char)) for char, name in disallowed
|
337 |
+
] + [
|
338 |
+
Primitive("r_dot", tpregex, emp_dot_no_letter(corpus) ),
|
339 |
+
Primitive("r_d", tpregex, emp_d(corpus) ),
|
340 |
+
Primitive("r_s", tpregex, pregex.s),
|
341 |
+
Primitive("r_kleene", arrow(tpregex, tpregex), _kleene),
|
342 |
+
#Primitive("r_plus", arrow(tpregex, tpregex), _plus),
|
343 |
+
#Primitive("r_maybe", arrow(tpregex, tpregex), _maybe),
|
344 |
+
Primitive("r_alt", arrow(tpregex, tpregex, tpregex), _alt),
|
345 |
+
Primitive("r_concat", arrow(tpregex, tpregex, tpregex), _concat),
|
346 |
+
]
|
347 |
+
|
348 |
+
|
349 |
+
if __name__=='__main__':
|
350 |
+
concatPrimitives()
|
351 |
+
from dreamcoder.program import Program
|
352 |
+
|
353 |
+
p=Program.parse("(lambda (r_kleene (lambda (r_maybe (lambda (string_x $0)) $0)) $0))")
|
354 |
+
print(p)
|
355 |
+
print(p.runWithArguments([pregex.String("")]))
|
356 |
+
|
357 |
+
prims = concatPrimitives()
|
358 |
+
g = Grammar.uniform(prims)
|
359 |
+
|
360 |
+
for i in range(100):
|
361 |
+
prog = g.sample(arrow(tpregex,tpregex))
|
362 |
+
preg = prog.runWithArguments([pregex.String("")])
|
363 |
+
print("preg:", preg.__repr__())
|
364 |
+
print("sample:", preg.sample())
|
365 |
+
|
366 |
+
|
367 |
+
|
dreamcoder/domains/text/__init__.py
ADDED
File without changes
|
dreamcoder/domains/text/main.py
ADDED
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dreamcoder.dreamcoder import ecIterator
|
2 |
+
from dreamcoder.domains.text.makeTextTasks import makeTasks, loadPBETasks
|
3 |
+
from dreamcoder.domains.text.textPrimitives import primitives
|
4 |
+
from dreamcoder.domains.list.listPrimitives import bootstrapTarget
|
5 |
+
from dreamcoder.enumeration import *
|
6 |
+
|
7 |
+
import os
|
8 |
+
import datetime
|
9 |
+
import random
|
10 |
+
from functools import reduce
|
11 |
+
import dill
|
12 |
+
|
13 |
+
|
14 |
+
class ConstantInstantiateVisitor(object):
|
15 |
+
def __init__(self, words):
|
16 |
+
self.words = words
|
17 |
+
|
18 |
+
def primitive(self, e):
|
19 |
+
if e.name == "STRING":
|
20 |
+
return Primitive("STRING", e.tp, random.choice(self.words))
|
21 |
+
return e
|
22 |
+
|
23 |
+
def invented(self, e): return e.body.visit(self)
|
24 |
+
|
25 |
+
def index(self, e): return e
|
26 |
+
|
27 |
+
def application(self, e):
|
28 |
+
return Application(e.f.visit(self), e.x.visit(self))
|
29 |
+
|
30 |
+
def abstraction(self, e):
|
31 |
+
return Abstraction(e.body.visit(self))
|
32 |
+
|
33 |
+
|
34 |
+
try:
|
35 |
+
from dreamcoder.recognition import *
|
36 |
+
class LearnedFeatureExtractor(RecurrentFeatureExtractor):
|
37 |
+
special = 'string'
|
38 |
+
|
39 |
+
def tokenize(self, examples):
|
40 |
+
def tokenize_example(xs,y):
|
41 |
+
if not isinstance(y, list): y = [y]
|
42 |
+
return xs,y
|
43 |
+
return [tokenize_example(*e) for e in examples]
|
44 |
+
|
45 |
+
def __init__(self, tasks, testingTasks=[], cuda=False):
|
46 |
+
lexicon = {c
|
47 |
+
for t in tasks + testingTasks
|
48 |
+
for xs, y in self.tokenize(t.examples)
|
49 |
+
for c in reduce(lambda u, v: u + v, list(xs) + [y])}
|
50 |
+
self.recomputeTasks = True
|
51 |
+
|
52 |
+
super(LearnedFeatureExtractor, self).__init__(lexicon=list(lexicon),
|
53 |
+
H=64,
|
54 |
+
tasks=tasks,
|
55 |
+
bidirectional=True,
|
56 |
+
cuda=cuda)
|
57 |
+
self.MAXINPUTS = 8
|
58 |
+
|
59 |
+
def taskOfProgram(self, p, tp):
|
60 |
+
# Instantiate STRING w/ random words
|
61 |
+
p = p.visit(ConstantInstantiateVisitor.SINGLE)
|
62 |
+
return super(LearnedFeatureExtractor, self).taskOfProgram(p, tp)
|
63 |
+
except:
|
64 |
+
pass
|
65 |
+
|
66 |
+
### COMPETITION CODE
|
67 |
+
|
68 |
+
def competeOnOneTask(checkpoint, task,
|
69 |
+
CPUs=8, timeout=3600, evaluationTimeout=0.0005):
|
70 |
+
if checkpoint.recognitionModel is not None:
|
71 |
+
recognizer = checkpoint.recognitionModel
|
72 |
+
challengeFrontiers, times, bestSearchTime = \
|
73 |
+
recognizer.enumerateFrontiers([task],
|
74 |
+
CPUs=CPUs,
|
75 |
+
maximumFrontier=1,
|
76 |
+
enumerationTimeout=timeout,
|
77 |
+
evaluationTimeout=evaluationTimeout)
|
78 |
+
else:
|
79 |
+
challengeFrontiers, times, bestSearchTimes = \
|
80 |
+
multicoreEnumeration(checkpoint.grammars[-1], [task],
|
81 |
+
CPUs=CPUs,
|
82 |
+
maximumFrontier=1,
|
83 |
+
enumerationTimeout=timeout,
|
84 |
+
evaluationTimeout=evaluationTimeout)
|
85 |
+
if len(times) == 0: return None, task
|
86 |
+
assert len(times) == 1
|
87 |
+
return times[0], task
|
88 |
+
|
89 |
+
|
90 |
+
|
91 |
+
def sygusCompetition(checkpoints, tasks):
|
92 |
+
from pathos.multiprocessing import Pool
|
93 |
+
import datetime
|
94 |
+
|
95 |
+
# map from task to list of search times, one for each checkpoint.
|
96 |
+
# search time will be None if it is not solved
|
97 |
+
searchTimes = {t: [] for t in tasks}
|
98 |
+
|
99 |
+
CPUs = int(8/len(checkpoints))
|
100 |
+
maxWorkers = int(numberOfCPUs()/CPUs)
|
101 |
+
workers = Pool(maxWorkers)
|
102 |
+
eprint(f"You gave me {len(checkpoints)} checkpoints to ensemble. Each checkpoint will get {CPUs} CPUs. Creating a pool of {maxWorkers} worker processes.")
|
103 |
+
timeout = 3600
|
104 |
+
|
105 |
+
|
106 |
+
promises = []
|
107 |
+
for t in tasks:
|
108 |
+
for checkpoint in checkpoints:
|
109 |
+
promise = workers.apply_async(competeOnOneTask,
|
110 |
+
(checkpoint,t),
|
111 |
+
{"CPUs": CPUs,
|
112 |
+
"timeout": timeout})
|
113 |
+
promises.append(promise)
|
114 |
+
eprint(f"Queued {len(promises)} jobs.")
|
115 |
+
for promise in promises:
|
116 |
+
dt, task = promise.get()
|
117 |
+
if dt is not None:
|
118 |
+
searchTimes[task].append(dt)
|
119 |
+
|
120 |
+
searchTimes = {t: min(ts) if len(ts) > 0 else None
|
121 |
+
for t,ts in searchTimes.items()}
|
122 |
+
|
123 |
+
fn = "experimentOutputs/text_competition_%s.p"%(datetime.datetime.now().isoformat())
|
124 |
+
with open(fn,"wb") as handle:
|
125 |
+
pickle.dump(searchTimes, handle)
|
126 |
+
eprint()
|
127 |
+
|
128 |
+
hits = sum( t is not None for t in searchTimes.values() )
|
129 |
+
total = len(searchTimes)
|
130 |
+
percentage = 100*hits/total
|
131 |
+
eprint("Hits %d/%d = %f\n"%(hits, total, percentage))
|
132 |
+
eprint()
|
133 |
+
eprint("Exported competition results to",fn)
|
134 |
+
|
135 |
+
|
136 |
+
|
137 |
+
def text_options(parser):
|
138 |
+
parser.add_argument(
|
139 |
+
"--showTasks",
|
140 |
+
action="store_true",
|
141 |
+
default=False,
|
142 |
+
help="show the training test and challenge tasks and then exit")
|
143 |
+
parser.add_argument(
|
144 |
+
"--trainChallenge",
|
145 |
+
action="store_true",
|
146 |
+
default=False,
|
147 |
+
help="Incorporate a random 50% of the challenge problems into the training set")
|
148 |
+
parser.add_argument(
|
149 |
+
"--onlyChallenge",
|
150 |
+
action="store_true",
|
151 |
+
default=False,
|
152 |
+
help="Only train on challenge problems and have testing problems.")
|
153 |
+
parser.add_argument(
|
154 |
+
"--latest",
|
155 |
+
action="store_true",
|
156 |
+
default=False,
|
157 |
+
help="evaluate on latest sygus problems rather than problems used in ec2 paper")
|
158 |
+
parser.add_argument(
|
159 |
+
"--noMap", action="store_true", default=False,
|
160 |
+
help="Disable built-in map primitive")
|
161 |
+
parser.add_argument(
|
162 |
+
"--noLength", action="store_true", default=False,
|
163 |
+
help="Disable built-in length primitive")
|
164 |
+
parser.add_argument(
|
165 |
+
"--noUnfold", action="store_true", default=False,
|
166 |
+
help="Disable built-in unfold primitive")
|
167 |
+
parser.add_argument(
|
168 |
+
"--compete",
|
169 |
+
nargs='+',
|
170 |
+
default=None,
|
171 |
+
type=str,
|
172 |
+
help="Do a simulated sygus competition (1hr+8cpus/problem) on the sygus tasks, restoring from provided checkpoint(s). If multiple checkpoints are provided, then we ensemble the models.")
|
173 |
+
|
174 |
+
|
175 |
+
def main(arguments):
|
176 |
+
"""
|
177 |
+
Takes the return value of the `commandlineArguments()` function as input and
|
178 |
+
trains/tests the model on manipulating sequences of text.
|
179 |
+
"""
|
180 |
+
|
181 |
+
tasks = makeTasks()
|
182 |
+
eprint("Generated", len(tasks), "tasks")
|
183 |
+
|
184 |
+
for t in tasks:
|
185 |
+
t.mustTrain = False
|
186 |
+
|
187 |
+
test, train = testTrainSplit(tasks, 1.)
|
188 |
+
eprint("Split tasks into %d/%d test/train" % (len(test), len(train)))
|
189 |
+
|
190 |
+
latest = arguments.pop("latest")
|
191 |
+
challenge, challengeCheating = loadPBETasks("data/sygus" if latest else "PBE_Strings_Track")
|
192 |
+
eprint("Got %d challenge PBE tasks" % len(challenge))
|
193 |
+
|
194 |
+
if arguments.pop('trainChallenge'):
|
195 |
+
challengeTest, challengeTrain = testTrainSplit(challenge, 0.5)
|
196 |
+
challenge = challengeTest
|
197 |
+
train += challengeTrain
|
198 |
+
eprint(
|
199 |
+
"Incorporating %d (50%%) challenge problems into the training set." %
|
200 |
+
(len(challengeTrain)),
|
201 |
+
"We will evaluate on the held out challenge problems.",
|
202 |
+
"This makes a total of %d training problems." %
|
203 |
+
len(train))
|
204 |
+
|
205 |
+
if arguments.pop('onlyChallenge'):
|
206 |
+
train = challenge
|
207 |
+
test = []
|
208 |
+
challenge = []
|
209 |
+
eprint("Training only on sygus problems.")
|
210 |
+
|
211 |
+
|
212 |
+
ConstantInstantiateVisitor.SINGLE = \
|
213 |
+
ConstantInstantiateVisitor(list(map(list, list({tuple([c for c in s])
|
214 |
+
for t in test + train + challenge
|
215 |
+
for s in t.stringConstants}))))
|
216 |
+
|
217 |
+
haveLength = not arguments.pop("noLength")
|
218 |
+
haveMap = not arguments.pop("noMap")
|
219 |
+
haveUnfold = not arguments.pop("noUnfold")
|
220 |
+
eprint(f"Including map as a primitive? {haveMap}")
|
221 |
+
eprint(f"Including length as a primitive? {haveLength}")
|
222 |
+
eprint(f"Including unfold as a primitive? {haveUnfold}")
|
223 |
+
baseGrammar = Grammar.uniform(primitives + [p
|
224 |
+
for p in bootstrapTarget()
|
225 |
+
if (p.name != "map" or haveMap) and \
|
226 |
+
(p.name != "unfold" or haveUnfold) and \
|
227 |
+
(p.name != "length" or haveLength)])
|
228 |
+
challengeGrammar = baseGrammar # Grammar.uniform(targetTextPrimitives)
|
229 |
+
|
230 |
+
evaluationTimeout = 0.0005
|
231 |
+
# We will spend 10 minutes on each challenge problem
|
232 |
+
challengeTimeout = 10 * 60
|
233 |
+
|
234 |
+
for t in train + test + challenge:
|
235 |
+
t.maxParameters = 2
|
236 |
+
|
237 |
+
if arguments.pop("showTasks"):
|
238 |
+
for source, ts in [("train",tasks),("test",test),("challenge",challenge)]:
|
239 |
+
print(source,"tasks:")
|
240 |
+
for t in ts:
|
241 |
+
print(t.name)
|
242 |
+
for xs, y in t.examples:
|
243 |
+
xs = ['"' + "".join(x) + '"' for x in xs]
|
244 |
+
y = "".join(y) if isinstance(y,list) else y
|
245 |
+
print('f(%s) = "%s"' % (", ".join(xs), y))
|
246 |
+
print("\t{%s}" % (t.stringConstants))
|
247 |
+
print()
|
248 |
+
sys.exit(0)
|
249 |
+
|
250 |
+
|
251 |
+
competitionCheckpoints = arguments.pop("compete")
|
252 |
+
if competitionCheckpoints:
|
253 |
+
checkpoints = []
|
254 |
+
for competitionCheckpoint in competitionCheckpoints:
|
255 |
+
with open(competitionCheckpoint, 'rb') as handle:
|
256 |
+
checkpoints.append(dill.load(handle))
|
257 |
+
sygusCompetition(checkpoints, challenge)
|
258 |
+
sys.exit(0)
|
259 |
+
|
260 |
+
timestamp = datetime.datetime.now().isoformat()
|
261 |
+
outputDirectory = "experimentOutputs/text/%s"%timestamp
|
262 |
+
os.system("mkdir -p %s"%outputDirectory)
|
263 |
+
|
264 |
+
generator = ecIterator(baseGrammar, train,
|
265 |
+
testingTasks=test + challenge,
|
266 |
+
outputPrefix="%s/text"%outputDirectory,
|
267 |
+
evaluationTimeout=evaluationTimeout,
|
268 |
+
**arguments)
|
269 |
+
for result in generator:
|
270 |
+
pass
|
dreamcoder/domains/text/makeTextTasks.py
ADDED
@@ -0,0 +1,424 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dreamcoder.task import *
|
2 |
+
from dreamcoder.type import *
|
3 |
+
from dreamcoder.utilities import *
|
4 |
+
|
5 |
+
import random
|
6 |
+
|
7 |
+
|
8 |
+
def lcs(u, v):
|
9 |
+
# t[(n,m)] = length of longest common string ending at first
|
10 |
+
# n elements of u & first m elements of v
|
11 |
+
t = {}
|
12 |
+
|
13 |
+
for n in range(len(u) + 1):
|
14 |
+
for m in range(len(v) + 1):
|
15 |
+
if m == 0 or n == 0:
|
16 |
+
t[(n, m)] = 0
|
17 |
+
continue
|
18 |
+
|
19 |
+
if u[n - 1] == v[m - 1]:
|
20 |
+
t[(n, m)] = 1 + t[(n - 1, m - 1)]
|
21 |
+
else:
|
22 |
+
t[(n, m)] = 0
|
23 |
+
l, n, m = max((l, n, m) for (n, m), l in t.items())
|
24 |
+
return u[n - l:n]
|
25 |
+
|
26 |
+
|
27 |
+
delimiters = ['.', ',', ' ', '(', ')', '-']
|
28 |
+
characters = [chr(ord('a') + j)
|
29 |
+
for j in range(26)] + \
|
30 |
+
[chr(ord('A') + j)
|
31 |
+
for j in range(26)] + \
|
32 |
+
[str(j) for j in range(10)] + \
|
33 |
+
['+']
|
34 |
+
|
35 |
+
WORDS = None
|
36 |
+
|
37 |
+
|
38 |
+
def randomDelimiter():
|
39 |
+
return random.choice(delimiters)
|
40 |
+
|
41 |
+
|
42 |
+
def randomCharacter():
|
43 |
+
return random.choice(characters)
|
44 |
+
|
45 |
+
|
46 |
+
def randomWord(minimum=1, predicate=None):
|
47 |
+
global WORDS
|
48 |
+
if WORDS is None:
|
49 |
+
tasks, cheating = loadPBETasks()
|
50 |
+
observations = {''.join(z)
|
51 |
+
for t in tasks
|
52 |
+
for xs, y in t.examples
|
53 |
+
for z in list(xs) + [y]}
|
54 |
+
|
55 |
+
def splitMany(s, ds):
|
56 |
+
if len(ds) == 0:
|
57 |
+
return [s]
|
58 |
+
d = ds[0]
|
59 |
+
ds = ds[1:]
|
60 |
+
s = [w
|
61 |
+
for z in s.split(d)
|
62 |
+
for w in splitMany(z, ds)
|
63 |
+
if len(w) > 0]
|
64 |
+
return s
|
65 |
+
|
66 |
+
WORDS = {w
|
67 |
+
for o in observations
|
68 |
+
for w in splitMany(o, delimiters)}
|
69 |
+
WORDS = list(sorted(list(WORDS)))
|
70 |
+
|
71 |
+
# a disproportionately large fraction of the words have length three
|
72 |
+
# the purpose of this is to decrease the number of 3-length words we have
|
73 |
+
while True:
|
74 |
+
if random.random() > 0.7:
|
75 |
+
candidate = random.choice([w for w in WORDS if len(w) >= minimum])
|
76 |
+
else:
|
77 |
+
candidate = random.choice(
|
78 |
+
[w for w in WORDS if len(w) >= minimum and len(w) != 3])
|
79 |
+
if predicate is None or predicate(candidate):
|
80 |
+
return candidate
|
81 |
+
|
82 |
+
|
83 |
+
def randomWords(ds, minimum=1, lb=2, ub=4):
|
84 |
+
words = [randomWord(minimum=minimum)
|
85 |
+
for _ in range(random.choice(range(lb, ub+1)))]
|
86 |
+
s = ""
|
87 |
+
for j,w in enumerate(words):
|
88 |
+
if j > 0:
|
89 |
+
s += random.choice(ds)
|
90 |
+
s += w
|
91 |
+
return s
|
92 |
+
|
93 |
+
|
94 |
+
def makeTasks():
|
95 |
+
import random
|
96 |
+
random.seed(9)
|
97 |
+
|
98 |
+
NUMBEROFEXAMPLES = 4
|
99 |
+
|
100 |
+
problems = []
|
101 |
+
|
102 |
+
def toList(s): return [c for c in s]
|
103 |
+
# Converts strings into a list of characters depending on the type
|
104 |
+
|
105 |
+
def preprocess(x):
|
106 |
+
if isinstance(x, tuple):
|
107 |
+
return tuple(preprocess(z) for z in x)
|
108 |
+
if isinstance(x, list):
|
109 |
+
return [preprocess(z) for z in x]
|
110 |
+
if isinstance(x, str):
|
111 |
+
return [c for c in x]
|
112 |
+
if isinstance(x, bool):
|
113 |
+
return x
|
114 |
+
assert False
|
115 |
+
|
116 |
+
def problem(n, examples, needToTrain=False):
|
117 |
+
task = Task(n, guess_arrow_type(examples),
|
118 |
+
[(preprocess(x),
|
119 |
+
preprocess(y))
|
120 |
+
for x, y in examples])
|
121 |
+
task.mustTrain = True
|
122 |
+
problems.append(task)
|
123 |
+
|
124 |
+
for d1, d2 in randomPermutation(crossProduct(delimiters, delimiters))[
|
125 |
+
:len(delimiters) * 2]:
|
126 |
+
if d1 != d2:
|
127 |
+
problem("Replace '%s' w/ '%s'" % (d1, d2),
|
128 |
+
[((x,), x.replace(d1, d2))
|
129 |
+
for _ in range(NUMBEROFEXAMPLES)
|
130 |
+
for x in [randomWords(d1)]],
|
131 |
+
needToTrain=False)
|
132 |
+
for d in delimiters:
|
133 |
+
problem("drop first word delimited by '%s'" % d,
|
134 |
+
[((x,), d.join(x.split(d)[1:]))
|
135 |
+
for _ in range(NUMBEROFEXAMPLES)
|
136 |
+
for x in [randomWords(d)]],
|
137 |
+
needToTrain=True)
|
138 |
+
for n in [0, 1, -1]:
|
139 |
+
problem("nth (n=%d) word delimited by '%s'" % (n, d),
|
140 |
+
[((x,), x.split(d)[n])
|
141 |
+
for _ in range(NUMBEROFEXAMPLES)
|
142 |
+
for x in [randomWords(d)]],
|
143 |
+
needToTrain=True)
|
144 |
+
for d1 in delimiters:
|
145 |
+
problem("Append two words delimited by '%s'" % (d1),
|
146 |
+
[((x, y), x + d1 + y)
|
147 |
+
for _ in range(NUMBEROFEXAMPLES)
|
148 |
+
for x in [randomWord()]
|
149 |
+
for y in [randomWord()]],
|
150 |
+
needToTrain=True)
|
151 |
+
for d1, d2 in randomPermutation(
|
152 |
+
crossProduct(
|
153 |
+
delimiters, delimiters))[
|
154 |
+
:len(delimiters)]:
|
155 |
+
problem("Append two words delimited by '%s%s'" % (d1, d2),
|
156 |
+
[((x, y), x + d1 + d2 + y)
|
157 |
+
for _ in range(NUMBEROFEXAMPLES)
|
158 |
+
for x in [randomWord()]
|
159 |
+
for y in [randomWord()]],
|
160 |
+
needToTrain=True)
|
161 |
+
for n in range(1, 6):
|
162 |
+
problem("Drop last %d characters" % n,
|
163 |
+
[((x,), x[:-n])
|
164 |
+
for _ in range(NUMBEROFEXAMPLES)
|
165 |
+
for x in [randomWord(minimum=n)]],
|
166 |
+
needToTrain=True)
|
167 |
+
if n > 1:
|
168 |
+
problem("Take first %d characters" % n,
|
169 |
+
[((x,), x[:n])
|
170 |
+
for _ in range(NUMBEROFEXAMPLES)
|
171 |
+
for x in [randomWord(minimum=n)]],
|
172 |
+
needToTrain=True)
|
173 |
+
for d1, d2 in randomPermutation(
|
174 |
+
crossProduct(
|
175 |
+
delimiters, delimiters))[
|
176 |
+
:len(delimiters)]:
|
177 |
+
problem("Extract word delimited by '%s' - '%s'" % (d1, d2),
|
178 |
+
[((a + d1 + b + d2 + c + d + e,), b)
|
179 |
+
for _ in range(int(NUMBEROFEXAMPLES / 2))
|
180 |
+
for d in [d1, d2]
|
181 |
+
for a in [randomWord()]
|
182 |
+
for b in [randomWord()]
|
183 |
+
for c in [randomWord()]
|
184 |
+
for e in [randomWord()]],
|
185 |
+
needToTrain=True)
|
186 |
+
|
187 |
+
for n in range(len(delimiters)):
|
188 |
+
problem("First letters of words (%s)" % ("I" * (1 + n)),
|
189 |
+
[((x,), "".join(map(lambda z: z[0], x.split(' '))))
|
190 |
+
for _ in range(NUMBEROFEXAMPLES)
|
191 |
+
for x in [randomWords(' ')]
|
192 |
+
],
|
193 |
+
needToTrain=True)
|
194 |
+
|
195 |
+
for d in delimiters:
|
196 |
+
problem("Take first character and append '%s'" % d,
|
197 |
+
[((x,), x[0] + d)
|
198 |
+
for _ in range(NUMBEROFEXAMPLES)
|
199 |
+
for x in [randomWord()]],
|
200 |
+
needToTrain=True)
|
201 |
+
|
202 |
+
for n in range(len(delimiters)):
|
203 |
+
problem("Abbreviate separate words (%s)" % ("I" * (n + 1)),
|
204 |
+
[((x, y), "%s.%s." % (x[0], y[0]))
|
205 |
+
for _ in range(NUMBEROFEXAMPLES)
|
206 |
+
for y in [randomWord()]
|
207 |
+
for x in [randomWord()]])
|
208 |
+
d = delimiters[n]
|
209 |
+
problem("Abbreviate words separated by '%s'" % d,
|
210 |
+
[((x + d + y,), "%s.%s." % (x[0], y[0]))
|
211 |
+
for _ in range(NUMBEROFEXAMPLES)
|
212 |
+
for y in [randomWord()]
|
213 |
+
for x in [randomWord()]])
|
214 |
+
|
215 |
+
for n in range(len(delimiters)):
|
216 |
+
problem("Append 2 strings (%s)" % ('I' * (n + 1)),
|
217 |
+
[((x, y), x + y)
|
218 |
+
for _ in range(NUMBEROFEXAMPLES)
|
219 |
+
for y in [randomWord()]
|
220 |
+
for x in [randomWord()]],
|
221 |
+
needToTrain=True)
|
222 |
+
|
223 |
+
for n in range(len(delimiters)):
|
224 |
+
w = randomWord(minimum=3)
|
225 |
+
problem("Prepend '%s'" % w,
|
226 |
+
[((x,), w + x)
|
227 |
+
for _ in range(NUMBEROFEXAMPLES)
|
228 |
+
for x in [randomWord()]])
|
229 |
+
w = randomWord(minimum=3)
|
230 |
+
problem("Append '%s'" % w,
|
231 |
+
[((x,), x + w)
|
232 |
+
for _ in range(NUMBEROFEXAMPLES)
|
233 |
+
for x in [randomWord()]])
|
234 |
+
w = randomWord(minimum=3)
|
235 |
+
problem("Prepend '%s' to first word" % w,
|
236 |
+
[((x + ' ' + y,), w + x)
|
237 |
+
for _ in range(NUMBEROFEXAMPLES)
|
238 |
+
for x in [randomWord()]
|
239 |
+
for y in [randomWord()]])
|
240 |
+
|
241 |
+
for n in range(1,6):
|
242 |
+
problem("parentheses around a single word (%s)"%('I'*n),
|
243 |
+
[((w,),"(%s)"%w)
|
244 |
+
for _ in range(NUMBEROFEXAMPLES)
|
245 |
+
for w in [randomWord()] ])
|
246 |
+
problem("parentheses around first word",
|
247 |
+
[((w + " " + s,),"(%s)"%w)
|
248 |
+
for _ in range(NUMBEROFEXAMPLES)
|
249 |
+
for w in [randomWord()]
|
250 |
+
for s in [randomWords(" ")] ])
|
251 |
+
problem("parentheses around second word",
|
252 |
+
[((s,), "(%s)"%(s.split(" ")[1]))
|
253 |
+
for _ in range(NUMBEROFEXAMPLES)
|
254 |
+
for s in [randomWords(" ")] ])
|
255 |
+
|
256 |
+
allowed = [d for d in delimiters if d not in "()"]
|
257 |
+
for d1,d2 in randomPermutation(crossProduct(allowed, allowed))[:len(delimiters)]:
|
258 |
+
problem("parentheses around word delimited by '%s' & '%s'"%(d1,d2),
|
259 |
+
[((prefix + d1 + word + d2 + suffix,),
|
260 |
+
prefix + d1 + '(' + word + ')' + d2 + suffix)
|
261 |
+
for _ in range(NUMBEROFEXAMPLES)
|
262 |
+
for prefix in [randomWords("", lb=0, ub=1)]
|
263 |
+
for suffix in [randomWords(allowed, ub=2, lb=1)]
|
264 |
+
for word in [randomWord()] ])
|
265 |
+
|
266 |
+
for n in range(7):
|
267 |
+
w = randomWord(minimum=3)
|
268 |
+
problem("ensure suffix `%s`"%w,
|
269 |
+
[ ((s + (w if f else ""),), s + w)
|
270 |
+
for _ in range(NUMBEROFEXAMPLES)
|
271 |
+
for s in [randomWords(" ")]
|
272 |
+
for f in [random.choice([True,False])] ])
|
273 |
+
|
274 |
+
|
275 |
+
for p in problems:
|
276 |
+
guessConstantStrings(p)
|
277 |
+
|
278 |
+
return problems
|
279 |
+
|
280 |
+
|
281 |
+
def loadPBETasks(directory="PBE_Strings_Track"):
|
282 |
+
"""
|
283 |
+
Processes sygus benchmarks into task objects
|
284 |
+
For these benchmarks, all of the constant strings are given to us.
|
285 |
+
In a sense this is cheating
|
286 |
+
Returns (tasksWithoutCheating, tasksWithCheating).
|
287 |
+
NB: Results in paper are done without "cheating"
|
288 |
+
"""
|
289 |
+
import os
|
290 |
+
from sexpdata import loads, Symbol
|
291 |
+
|
292 |
+
def findStrings(s):
|
293 |
+
if isinstance(s, list):
|
294 |
+
return [y
|
295 |
+
for x in s
|
296 |
+
for y in findStrings(x)]
|
297 |
+
if isinstance(s, str):
|
298 |
+
return [s]
|
299 |
+
return []
|
300 |
+
|
301 |
+
def explode(s):
|
302 |
+
return [c for c in s]
|
303 |
+
|
304 |
+
tasks = []
|
305 |
+
cheatingTasks = []
|
306 |
+
for f in os.listdir(directory):
|
307 |
+
if not f.endswith('.sl'):
|
308 |
+
continue
|
309 |
+
with open(directory + "/" + f, "r") as handle:
|
310 |
+
message = "(%s)" % (handle.read())
|
311 |
+
|
312 |
+
expression = loads(message)
|
313 |
+
|
314 |
+
constants = []
|
315 |
+
name = f
|
316 |
+
examples = []
|
317 |
+
declarative = False
|
318 |
+
for e in expression:
|
319 |
+
if len(e) == 0:
|
320 |
+
continue
|
321 |
+
if e[0] == Symbol('constraint'):
|
322 |
+
e = e[1]
|
323 |
+
assert e[0] == Symbol('=')
|
324 |
+
inputs = e[1]
|
325 |
+
assert inputs[0] == Symbol('f')
|
326 |
+
inputs = inputs[1:]
|
327 |
+
output = e[2]
|
328 |
+
examples.append((inputs, output))
|
329 |
+
elif e[0] == Symbol('synth-fun'):
|
330 |
+
if e[1] == Symbol('f'):
|
331 |
+
constants += findStrings(e)
|
332 |
+
else:
|
333 |
+
declarative = True
|
334 |
+
break
|
335 |
+
if declarative: continue
|
336 |
+
|
337 |
+
examples = list({(tuple(xs), y) for xs, y in examples})
|
338 |
+
|
339 |
+
task = Task(name, arrow(*[tstr] * (len(examples[0][0]) + 1)),
|
340 |
+
[(tuple(map(explode, xs)), explode(y))
|
341 |
+
for xs, y in examples])
|
342 |
+
cheat = task
|
343 |
+
|
344 |
+
tasks.append(task)
|
345 |
+
cheatingTasks.append(cheat)
|
346 |
+
|
347 |
+
for p in tasks:
|
348 |
+
guessConstantStrings(p)
|
349 |
+
return tasks, cheatingTasks
|
350 |
+
|
351 |
+
|
352 |
+
def guessConstantStrings(task):
|
353 |
+
if task.request.returns() == tlist(tcharacter):
|
354 |
+
examples = task.examples
|
355 |
+
guesses = {}
|
356 |
+
N = 10
|
357 |
+
T = 2
|
358 |
+
for n in range(min(N, len(examples))):
|
359 |
+
for m in range(n + 1, min(N, len(examples))):
|
360 |
+
y1 = examples[n][1]
|
361 |
+
y2 = examples[m][1]
|
362 |
+
l = ''.join(lcs(y1, y2))
|
363 |
+
if len(l) > 2:
|
364 |
+
guesses[l] = guesses.get(l, 0) + 1
|
365 |
+
|
366 |
+
task.stringConstants = [g for g, f in guesses.items()
|
367 |
+
if f >= T]
|
368 |
+
else:
|
369 |
+
task.stringConstants = []
|
370 |
+
|
371 |
+
|
372 |
+
task.BIC = 1.
|
373 |
+
task.maxParameters = 1
|
374 |
+
|
375 |
+
task.specialTask = ("stringConstant",
|
376 |
+
{"maxParameters": task.maxParameters,
|
377 |
+
"stringConstants": task.stringConstants})
|
378 |
+
|
379 |
+
|
380 |
+
if __name__ == "__main__":
|
381 |
+
challenge, _ = loadPBETasks("data/sygus")
|
382 |
+
|
383 |
+
tasks = makeTasks()
|
384 |
+
print(len(tasks), "synthetic tasks")
|
385 |
+
tasks = []
|
386 |
+
for t in tasks + challenge:
|
387 |
+
print(t.name)
|
388 |
+
for xs, y in t.examples:
|
389 |
+
xs = ['"' + "".join(x) + '"' for x in xs]
|
390 |
+
y = "".join(y)
|
391 |
+
print('f(%s) = "%s"' % (", ".join(xs), y))
|
392 |
+
print("\t{%s}" % (t.stringConstants))
|
393 |
+
print()
|
394 |
+
assert False
|
395 |
+
# def maximumLength(x):
|
396 |
+
# if isinstance(x,list):
|
397 |
+
# return max([len(x)] + map(maximumLength,x))
|
398 |
+
# return 1
|
399 |
+
|
400 |
+
# print max(maximumLength(z) for t in tasks
|
401 |
+
# for (x,),y in t.examples
|
402 |
+
# for z in [x,y] )
|
403 |
+
|
404 |
+
if len(sys.argv) > 1 and "json" in sys.argv[1]:
|
405 |
+
import json
|
406 |
+
tasks = makeTasks()
|
407 |
+
obj = [t.as_json_dict() for t in tasks]
|
408 |
+
json.dump(obj, sys.stdout)
|
409 |
+
else:
|
410 |
+
as_tex = len(sys.argv) > 1 and "tex" in sys.argv[1]
|
411 |
+
for t in tasks:
|
412 |
+
print(t.name)
|
413 |
+
print(t.request)
|
414 |
+
if as_tex:
|
415 |
+
print("""\\begin{tabular}{ll}
|
416 |
+
\\toprule Input&Output\\\\\\midrule
|
417 |
+
%s
|
418 |
+
\\\\\\bottomrule
|
419 |
+
\\end{tabular}""" % (" \\\\\n ".join(x[0] + " & " + y for x, y in t.examples)))
|
420 |
+
else:
|
421 |
+
for x, y in t.examples:
|
422 |
+
print(x[0], '\t', y)
|
423 |
+
print()
|
424 |
+
print(len(tasks), "tasks")
|
dreamcoder/domains/text/textPrimitives.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dreamcoder.program import *
|
2 |
+
from dreamcoder.domains.text.makeTextTasks import delimiters
|
3 |
+
|
4 |
+
def _isUpper(x): return x.isupper()
|
5 |
+
|
6 |
+
def _increment(x): return x + 1
|
7 |
+
|
8 |
+
|
9 |
+
def _decrement(x): return x - 1
|
10 |
+
|
11 |
+
|
12 |
+
def _lower(x): return x.lower()
|
13 |
+
|
14 |
+
|
15 |
+
def _upper(x): return x.upper()
|
16 |
+
|
17 |
+
|
18 |
+
def _capitalize(x): return x.capitalize()
|
19 |
+
|
20 |
+
|
21 |
+
def _append(x): return lambda y: x + y
|
22 |
+
|
23 |
+
|
24 |
+
def _slice(x): return lambda y: lambda s: s[x:y]
|
25 |
+
|
26 |
+
|
27 |
+
def _index(n): return lambda x: x[n]
|
28 |
+
|
29 |
+
|
30 |
+
def _map(f): return lambda x: list(map(f, x))
|
31 |
+
|
32 |
+
|
33 |
+
def _find(pattern): return lambda s: s.index(pattern)
|
34 |
+
|
35 |
+
|
36 |
+
def _replace(original): return lambda replacement: lambda target: target.replace(
|
37 |
+
original, replacement)
|
38 |
+
|
39 |
+
|
40 |
+
def _split(delimiter): return lambda s: s.split(delimiter)
|
41 |
+
|
42 |
+
|
43 |
+
def _join(delimiter): return lambda ss: delimiter.join(ss)
|
44 |
+
|
45 |
+
|
46 |
+
def _identity(x): return x
|
47 |
+
#def _reverse(x): return x[::-1]
|
48 |
+
|
49 |
+
|
50 |
+
def _strip(x): return x.strip()
|
51 |
+
|
52 |
+
|
53 |
+
def _eq(x): return lambda y: x == y
|
54 |
+
|
55 |
+
|
56 |
+
specialCharacters = {' ': 'SPACE',
|
57 |
+
')': 'RPAREN',
|
58 |
+
'(': 'LPAREN'}
|
59 |
+
|
60 |
+
primitives = [
|
61 |
+
Primitive("char-eq?", arrow(tcharacter, tcharacter, tboolean), _eq),
|
62 |
+
Primitive("STRING", tstr, None)
|
63 |
+
] + [Primitive("'%s'" % d, tcharacter, d) for d in delimiters if d not in specialCharacters] + \
|
64 |
+
[Primitive(name, tcharacter, value) for value, name in specialCharacters.items()]
|
65 |
+
|
66 |
+
|
67 |
+
def _cons(x): return lambda y: [x] + y
|
68 |
+
|
69 |
+
|
70 |
+
def _car(x): return x[0]
|
71 |
+
|
72 |
+
|
73 |
+
def _cdr(x): return x[1:]
|
74 |
+
|
75 |
+
|
76 |
+
targetTextPrimitives = [
|
77 |
+
Primitive("take-word", arrow(tcharacter, tstr, tstr), None),
|
78 |
+
Primitive("drop-word", arrow(tcharacter, tstr, tstr), None),
|
79 |
+
Primitive("append", arrow(tlist(t0), tlist(t0), tlist(t0)), None),
|
80 |
+
Primitive("abbreviate", arrow(tstr, tstr), None),
|
81 |
+
Primitive("last-word", arrow(tcharacter, tstr, tstr), None),
|
82 |
+
Primitive("replace-character", arrow(tcharacter, tcharacter, tstr, tstr), None),
|
83 |
+
] + primitives + [
|
84 |
+
Primitive("empty", tlist(t0), []),
|
85 |
+
Primitive("cons", arrow(t0, tlist(t0), tlist(t0)), _cons),
|
86 |
+
Primitive("car", arrow(tlist(t0), t0), _car),
|
87 |
+
Primitive("cdr", arrow(tlist(t0), tlist(t0)), _cdr)]
|
dreamcoder/domains/tower/__init__.py
ADDED
File without changes
|
dreamcoder/domains/tower/main.py
ADDED
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dreamcoder.dreamcoder import *
|
2 |
+
|
3 |
+
from dreamcoder.domains.tower.towerPrimitives import primitives, new_primitives, animateTower
|
4 |
+
from dreamcoder.domains.tower.makeTowerTasks import *
|
5 |
+
from dreamcoder.domains.tower.tower_common import renderPlan, towerLength, centerTower
|
6 |
+
from dreamcoder.utilities import *
|
7 |
+
|
8 |
+
import os
|
9 |
+
import datetime
|
10 |
+
|
11 |
+
try: #pypy will fail
|
12 |
+
from dreamcoder.recognition import variable
|
13 |
+
import torch.nn as nn
|
14 |
+
import torch.nn.functional as F
|
15 |
+
|
16 |
+
class Flatten(nn.Module):
|
17 |
+
def __init__(self):
|
18 |
+
super(Flatten, self).__init__()
|
19 |
+
|
20 |
+
def forward(self, x):
|
21 |
+
return x.view(x.size(0), -1)
|
22 |
+
|
23 |
+
|
24 |
+
class TowerCNN(nn.Module):
|
25 |
+
special = 'tower'
|
26 |
+
|
27 |
+
def __init__(self, tasks, testingTasks=[], cuda=False, H=64):
|
28 |
+
super(TowerCNN, self).__init__()
|
29 |
+
self.CUDA = cuda
|
30 |
+
self.recomputeTasks = True
|
31 |
+
|
32 |
+
self.outputDimensionality = H
|
33 |
+
def conv_block(in_channels, out_channels):
|
34 |
+
return nn.Sequential(
|
35 |
+
nn.Conv2d(in_channels, out_channels, 3, padding=1),
|
36 |
+
# nn.BatchNorm2d(out_channels),
|
37 |
+
nn.ReLU(),
|
38 |
+
nn.MaxPool2d(2)
|
39 |
+
)
|
40 |
+
|
41 |
+
self.inputImageDimension = 256
|
42 |
+
self.resizedDimension = 64
|
43 |
+
assert self.inputImageDimension % self.resizedDimension == 0
|
44 |
+
|
45 |
+
# channels for hidden
|
46 |
+
hid_dim = 64
|
47 |
+
z_dim = 64
|
48 |
+
|
49 |
+
self.encoder = nn.Sequential(
|
50 |
+
conv_block(6, hid_dim),
|
51 |
+
conv_block(hid_dim, hid_dim),
|
52 |
+
conv_block(hid_dim, hid_dim),
|
53 |
+
conv_block(hid_dim, z_dim),
|
54 |
+
Flatten()
|
55 |
+
)
|
56 |
+
|
57 |
+
self.outputDimensionality = 1024
|
58 |
+
|
59 |
+
if cuda:
|
60 |
+
self.CUDA=True
|
61 |
+
self.cuda() # I think this should work?
|
62 |
+
|
63 |
+
def forward(self, v, v2=None):
|
64 |
+
"""v: tower to build. v2: image of tower we have built so far"""
|
65 |
+
# insert batch if it is not already there
|
66 |
+
if len(v.shape) == 3:
|
67 |
+
v = np.expand_dims(v, 0)
|
68 |
+
inserted_batch = True
|
69 |
+
if v2 is not None:
|
70 |
+
assert len(v2.shape) == 3
|
71 |
+
v2 = np.expand_dims(v2, 0)
|
72 |
+
elif len(v.shape) == 4:
|
73 |
+
inserted_batch = False
|
74 |
+
pass
|
75 |
+
else:
|
76 |
+
assert False, "v has the shape %s"%(str(v.shape))
|
77 |
+
|
78 |
+
if v2 is None: v2 = np.zeros(v.shape)
|
79 |
+
|
80 |
+
v = np.concatenate((v,v2), axis=3)
|
81 |
+
v = np.transpose(v,(0,3,1,2))
|
82 |
+
assert v.shape == (v.shape[0], 6,self.inputImageDimension,self.inputImageDimension)
|
83 |
+
v = variable(v, cuda=self.CUDA).float()
|
84 |
+
window = int(self.inputImageDimension/self.resizedDimension)
|
85 |
+
v = F.avg_pool2d(v, (window,window))
|
86 |
+
#showArrayAsImage(np.transpose(v.data.numpy()[0,:3,:,:],[1,2,0]))
|
87 |
+
v = self.encoder(v)
|
88 |
+
if inserted_batch:
|
89 |
+
return v.view(-1)
|
90 |
+
else:
|
91 |
+
return v
|
92 |
+
|
93 |
+
def featuresOfTask(self, t, t2=None): # Take a task and returns [features]
|
94 |
+
return self(t.getImage(),
|
95 |
+
None if t2 is None else t2.getImage(drawHand=True))
|
96 |
+
|
97 |
+
def featuresOfTasks(self, ts, t2=None): # Take a task and returns [features]
|
98 |
+
"""Takes the goal first; optionally also takes the current state second"""
|
99 |
+
if t2 is None:
|
100 |
+
pass
|
101 |
+
elif isinstance(t2, Task):
|
102 |
+
assert False
|
103 |
+
#t2 = np.array([t2.getImage(drawHand=True)]*len(ts))
|
104 |
+
elif isinstance(t2, list):
|
105 |
+
t2 = np.array([t.getImage(drawHand=True) if t else np.zeros((self.inputImageDimension,
|
106 |
+
self.inputImageDimension,
|
107 |
+
3))
|
108 |
+
for t in t2])
|
109 |
+
else:
|
110 |
+
assert False
|
111 |
+
|
112 |
+
return self(np.array([t.getImage() for t in ts]),
|
113 |
+
t2)
|
114 |
+
|
115 |
+
def taskOfProgram(self, p, t,
|
116 |
+
lenient=False):
|
117 |
+
try:
|
118 |
+
pl = executeTower(p,0.05)
|
119 |
+
if pl is None or (not lenient and len(pl) == 0): return None
|
120 |
+
if len(pl) > 100 or towerLength(pl) > 360: return None
|
121 |
+
|
122 |
+
t = SupervisedTower("tower dream", p)
|
123 |
+
return t
|
124 |
+
except Exception as e:
|
125 |
+
return None
|
126 |
+
except: pass
|
127 |
+
|
128 |
+
|
129 |
+
|
130 |
+
def tower_options(parser):
|
131 |
+
parser.add_argument("--tasks",
|
132 |
+
choices=["old","new"],
|
133 |
+
default="old")
|
134 |
+
parser.add_argument("--visualize",
|
135 |
+
default=None, type=str)
|
136 |
+
parser.add_argument("--solutions",
|
137 |
+
default=None, type=str)
|
138 |
+
parser.add_argument("--split",
|
139 |
+
default=1., type=float)
|
140 |
+
parser.add_argument("--dream",
|
141 |
+
default=None, type=str)
|
142 |
+
parser.add_argument("--primitives",
|
143 |
+
default="old", type=str,
|
144 |
+
choices=["new", "old"])
|
145 |
+
|
146 |
+
|
147 |
+
def dreamOfTowers(grammar, prefix, N=250, make_montage=True):
|
148 |
+
request = arrow(ttower,ttower)
|
149 |
+
randomTowers = [tuple(centerTower(t))
|
150 |
+
for _ in range(N)
|
151 |
+
for program in [grammar.sample(request,
|
152 |
+
maximumDepth=12,
|
153 |
+
maxAttempts=100)]
|
154 |
+
if program is not None
|
155 |
+
for t in [executeTower(program, timeout=0.5) or []]
|
156 |
+
if len(t) >= 1 and len(t) < 100 and towerLength(t) <= 360.]
|
157 |
+
matrix = [renderPlan(p,Lego=True,pretty=True)
|
158 |
+
for p in randomTowers]
|
159 |
+
|
160 |
+
# Only visualize if it has something to visualize.
|
161 |
+
if len(matrix) > 0:
|
162 |
+
import scipy.misc
|
163 |
+
if make_montage:
|
164 |
+
matrix = montage(matrix)
|
165 |
+
scipy.misc.imsave('%s.png'%prefix, matrix)
|
166 |
+
else:
|
167 |
+
for n,i in enumerate(matrix):
|
168 |
+
scipy.misc.imsave(f'{prefix}/{n}.png', i)
|
169 |
+
else:
|
170 |
+
eprint("Tried to visualize dreams, but none to visualize.")
|
171 |
+
|
172 |
+
|
173 |
+
def visualizePrimitives(primitives, fn=None):
|
174 |
+
from itertools import product
|
175 |
+
#from pylab import imshow,show
|
176 |
+
|
177 |
+
from dreamcoder.domains.tower.towerPrimitives import _left,_right,_loop,_embed,_empty_tower,TowerState
|
178 |
+
_13 = Program.parse("1x3").value
|
179 |
+
_31 = Program.parse("3x1").value
|
180 |
+
|
181 |
+
r = lambda n,k: _right(2*n)(k)
|
182 |
+
l = lambda n,k: _left(2*n)(k)
|
183 |
+
_e = _embed
|
184 |
+
_lp = lambda n,b,k: _loop(n)(b)(k)
|
185 |
+
_arch = lambda k: l(1,_13(r(2,_13(l(1,_31(k))))))
|
186 |
+
_tallArch = lambda h,z,k: _lp(h, lambda _: _13(r(2,_13(l(2,z)))),
|
187 |
+
r(1,_31(k)))
|
188 |
+
|
189 |
+
matrix = []
|
190 |
+
for p in primitives:
|
191 |
+
if not p.isInvented: continue
|
192 |
+
eprint(p,":",p.tp)
|
193 |
+
t = p.tp
|
194 |
+
if t.returns() != ttower: continue
|
195 |
+
|
196 |
+
def argumentChoices(t):
|
197 |
+
if t == ttower:
|
198 |
+
return [_empty_tower]
|
199 |
+
elif t == tint:
|
200 |
+
return list(range(5))
|
201 |
+
elif t == arrow(ttower,ttower):
|
202 |
+
return [_arch,_13,_31]
|
203 |
+
else:
|
204 |
+
return []
|
205 |
+
|
206 |
+
ts = []
|
207 |
+
for arguments in product(*[argumentChoices(t) for t in t.functionArguments() ]):
|
208 |
+
t = p.evaluate([])
|
209 |
+
for a in arguments: t = t(a)
|
210 |
+
t = t(TowerState())[1]
|
211 |
+
ts.append(t)
|
212 |
+
|
213 |
+
if ts == []: continue
|
214 |
+
|
215 |
+
matrix.append([renderPlan(p,pretty=True)
|
216 |
+
for p in ts])
|
217 |
+
|
218 |
+
# Only visualize if it has something to visualize.
|
219 |
+
if len(matrix) > 0:
|
220 |
+
matrix = montageMatrix(matrix)
|
221 |
+
# imshow(matrix)
|
222 |
+
|
223 |
+
import scipy.misc
|
224 |
+
scipy.misc.imsave(fn, matrix)
|
225 |
+
# show()
|
226 |
+
else:
|
227 |
+
eprint("Tried to visualize primitives, but none to visualize.")
|
228 |
+
|
229 |
+
def animateSolutions(checkpoint):
|
230 |
+
with open(checkpoint,"rb") as handle: result = dill.load(handle)
|
231 |
+
for n,f in enumerate(result.taskSolutions.values()):
|
232 |
+
animateTower(f"/tmp/tower_animation_{n}",f.bestPosterior.program)
|
233 |
+
|
234 |
+
def visualizeSolutions(solutions, export, tasks=None):
|
235 |
+
|
236 |
+
if tasks is None:
|
237 |
+
tasks = list(solutions.keys())
|
238 |
+
tasks.sort(key=lambda t: len(t.plan))
|
239 |
+
|
240 |
+
matrix = []
|
241 |
+
for t in tasks:
|
242 |
+
i = renderPlan(centerTower(t.plan),pretty=True,Lego=True)
|
243 |
+
if solutions[t].empty: i = i/3.
|
244 |
+
matrix.append(i)
|
245 |
+
|
246 |
+
# Only visualize if it has something to visualize.
|
247 |
+
if len(matrix) > 0:
|
248 |
+
matrix = montage(matrix)
|
249 |
+
import scipy.misc
|
250 |
+
scipy.misc.imsave(export, matrix)
|
251 |
+
else:
|
252 |
+
eprint("Tried to visualize solutions, but none to visualize.")
|
253 |
+
|
254 |
+
|
255 |
+
def main(arguments):
|
256 |
+
"""
|
257 |
+
Takes the return value of the `commandlineArguments()` function as input and
|
258 |
+
trains/tests the model on a set of tower-building tasks.
|
259 |
+
"""
|
260 |
+
|
261 |
+
# The below global statement is required since primitives is modified within main().
|
262 |
+
# TODO(lcary): use a function call to retrieve and declare primitives instead.
|
263 |
+
global primitives
|
264 |
+
|
265 |
+
import scipy.misc
|
266 |
+
|
267 |
+
g0 = Grammar.uniform({"new": new_primitives,
|
268 |
+
"old": primitives}[arguments.pop("primitives")],
|
269 |
+
continuationType=ttower)
|
270 |
+
|
271 |
+
checkpoint = arguments.pop("visualize")
|
272 |
+
if checkpoint is not None:
|
273 |
+
with open(checkpoint,'rb') as handle:
|
274 |
+
primitives = pickle.load(handle).grammars[-1].primitives
|
275 |
+
visualizePrimitives(primitives)
|
276 |
+
sys.exit(0)
|
277 |
+
checkpoint = arguments.pop("solutions")
|
278 |
+
if checkpoint is not None:
|
279 |
+
with open(checkpoint,'rb') as handle:
|
280 |
+
solutions = pickle.load(handle).taskSolutions
|
281 |
+
visualizeSolutions(solutions,
|
282 |
+
checkpoint + ".solutions.png")
|
283 |
+
animateSolutions(checkpoint)
|
284 |
+
sys.exit(0)
|
285 |
+
checkpoint = arguments.pop("dream")
|
286 |
+
if checkpoint is not None:
|
287 |
+
with open(checkpoint,'rb') as handle:
|
288 |
+
g = pickle.load(handle).grammars[-1]
|
289 |
+
os.system("mkdir -p data/tower_dreams")
|
290 |
+
dreamOfTowers(g, "data/tower_dreams", make_montage=False)
|
291 |
+
sys.exit(0)
|
292 |
+
|
293 |
+
|
294 |
+
tasks = arguments.pop("tasks")
|
295 |
+
if tasks == "new":
|
296 |
+
tasks = makeSupervisedTasks()
|
297 |
+
elif tasks == "old":
|
298 |
+
tasks = makeOldSupervisedTasks()
|
299 |
+
else: assert False
|
300 |
+
|
301 |
+
test, train = testTrainSplit(tasks, arguments.pop("split"))
|
302 |
+
eprint("Split %d/%d test/train" % (len(test), len(train)))
|
303 |
+
|
304 |
+
# Make a montage for the paper
|
305 |
+
shuffledTrain = list(train)
|
306 |
+
shuffledTest = list(test)
|
307 |
+
random.shuffle(shuffledTrain)
|
308 |
+
shuffledTrain = shuffledTrain + [None]*(60 - len(shuffledTrain))
|
309 |
+
random.shuffle(shuffledTest)
|
310 |
+
shuffledTest = shuffledTest + [None]*(60 - len(shuffledTest))
|
311 |
+
try:
|
312 |
+
SupervisedTower.exportMany("/tmp/every_tower.png",shuffledTrain + shuffledTest, shuffle=False, columns=10)
|
313 |
+
for j,task in enumerate(tasks):
|
314 |
+
task.exportImage(f"/tmp/tower_task_{j}.png")
|
315 |
+
for k,v in dSLDemo().items():
|
316 |
+
scipy.misc.imsave(f"/tmp/tower_dsl_{k}.png", v)
|
317 |
+
os.system(f"convert /tmp/tower_dsl_{k}.png -channel RGB -negate /tmp/tower_dsl_{k}.png")
|
318 |
+
except:
|
319 |
+
eprint("WARNING: can't export images. scipy needs to be an older version")
|
320 |
+
|
321 |
+
|
322 |
+
timestamp = datetime.datetime.now().isoformat()
|
323 |
+
outputDirectory = "experimentOutputs/towers/%s"%timestamp
|
324 |
+
os.system("mkdir -p %s"%outputDirectory)
|
325 |
+
|
326 |
+
os.system("mkdir -p data/tower_dreams_initial")
|
327 |
+
try:
|
328 |
+
dreamOfTowers(g0, "data/tower_dreams_initial", make_montage=False)
|
329 |
+
dreamOfTowers(g0, "%s/random_0"%outputDirectory)
|
330 |
+
except:
|
331 |
+
eprint("WARNING: can't export images. scipy needs to be an older version")
|
332 |
+
|
333 |
+
evaluationTimeout = 0.005
|
334 |
+
generator = ecIterator(g0, train,
|
335 |
+
testingTasks=test,
|
336 |
+
outputPrefix="%s/tower"%outputDirectory,
|
337 |
+
evaluationTimeout=evaluationTimeout,
|
338 |
+
**arguments)
|
339 |
+
|
340 |
+
|
341 |
+
|
342 |
+
for result in generator:
|
343 |
+
continue
|
344 |
+
iteration = len(result.learningCurve)
|
345 |
+
newTowers = [tuple(centerTower(executeTower(frontier.sample().program)))
|
346 |
+
for frontier in result.taskSolutions.values() if not frontier.empty]
|
347 |
+
try:
|
348 |
+
fn = '%s/solutions_%d.png'%(outputDirectory,iteration)
|
349 |
+
visualizeSolutions(result.taskSolutions, fn,
|
350 |
+
train)
|
351 |
+
eprint("Exported solutions to %s\n"%fn)
|
352 |
+
dreamOfTowers(result.grammars[-1],
|
353 |
+
'%s/random_%d'%(outputDirectory,iteration))
|
354 |
+
except ImportError:
|
355 |
+
eprint("Could not import required libraries for exporting towers.")
|
356 |
+
primitiveFilename = '%s/primitives_%d.png'%(outputDirectory, iteration)
|
357 |
+
visualizePrimitives(result.grammars[-1].primitives,
|
358 |
+
primitiveFilename)
|
359 |
+
eprint("Exported primitives to",primitiveFilename)
|
dreamcoder/domains/tower/makeTowerTasks.py
ADDED
@@ -0,0 +1,556 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dreamcoder.domains.tower.towerPrimitives import ttower, executeTower, _empty_tower, TowerState
|
2 |
+
from dreamcoder.domains.tower.tower_common import renderPlan
|
3 |
+
from dreamcoder.task import *
|
4 |
+
|
5 |
+
|
6 |
+
class SupervisedTower(Task):
|
7 |
+
def __init__(self, name, program, mustTrain=False):
|
8 |
+
if isinstance(program,str):
|
9 |
+
try:
|
10 |
+
program = parseTower(program)
|
11 |
+
except:
|
12 |
+
eprint("Parse failure:")
|
13 |
+
eprint(program)
|
14 |
+
assert False
|
15 |
+
self.original = program
|
16 |
+
plan = executeTower(program)
|
17 |
+
elif isinstance(program,Program):
|
18 |
+
self.original = program
|
19 |
+
plan = executeTower(program)
|
20 |
+
else:
|
21 |
+
plan = program
|
22 |
+
self.original = program
|
23 |
+
state, self.plan = program.evaluate([])(_empty_tower)(TowerState())
|
24 |
+
self.hand = state.hand
|
25 |
+
super(SupervisedTower, self).__init__(name, arrow(ttower,ttower), [],
|
26 |
+
features=[])
|
27 |
+
self.specialTask = ("supervisedTower",
|
28 |
+
{"plan": self.plan})
|
29 |
+
self.image = None
|
30 |
+
self.handImage = None
|
31 |
+
self.mustTrain = mustTrain
|
32 |
+
|
33 |
+
def getImage(self, drawHand=False, pretty=False):
|
34 |
+
if not drawHand:
|
35 |
+
if not pretty:
|
36 |
+
if self.image is not None: return self.image
|
37 |
+
self.image = renderPlan(self.plan, pretty=pretty)
|
38 |
+
return self.image
|
39 |
+
else:
|
40 |
+
return renderPlan(self.plan, pretty=True)
|
41 |
+
else:
|
42 |
+
if self.handImage is not None: return self.handImage
|
43 |
+
self.handImage = renderPlan(self.plan,
|
44 |
+
drawHand=self.hand,
|
45 |
+
pretty=pretty)
|
46 |
+
return self.handImage
|
47 |
+
|
48 |
+
|
49 |
+
|
50 |
+
# do not pickle the image
|
51 |
+
def __getstate__(self):
|
52 |
+
return self.specialTask, self.plan, self.request, self.cache, self.name, self.examples
|
53 |
+
def __setstate__(self, state):
|
54 |
+
self.specialTask, self.plan, self.request, self.cache, self.name, self.examples = state
|
55 |
+
self.image = None
|
56 |
+
|
57 |
+
|
58 |
+
def animate(self):
|
59 |
+
from pylab import imshow,show
|
60 |
+
a = renderPlan(self.plan)
|
61 |
+
imshow(a)
|
62 |
+
show()
|
63 |
+
|
64 |
+
@staticmethod
|
65 |
+
def showMany(ts):
|
66 |
+
from pylab import imshow,show
|
67 |
+
a = montage([renderPlan(t.plan, pretty=True, Lego=True, resolution=256,
|
68 |
+
drawHand=False)
|
69 |
+
for t in ts])
|
70 |
+
imshow(a)
|
71 |
+
show()
|
72 |
+
|
73 |
+
@staticmethod
|
74 |
+
def exportMany(f, ts, shuffle=True, columns=None):
|
75 |
+
import numpy as np
|
76 |
+
|
77 |
+
ts = list(ts)
|
78 |
+
if shuffle:
|
79 |
+
assert all( t is not None for t in ts )
|
80 |
+
random.shuffle(ts)
|
81 |
+
a = montage([renderPlan(t.plan, pretty=True, Lego=True, resolution=256) if t is not None \
|
82 |
+
else np.zeros((256,256,3))
|
83 |
+
for t in ts],
|
84 |
+
columns=columns)
|
85 |
+
import scipy.misc
|
86 |
+
scipy.misc.imsave(f, a)
|
87 |
+
|
88 |
+
|
89 |
+
def exportImage(self, f, pretty=True, Lego=True, drawHand=False):
|
90 |
+
a = renderPlan(self.plan,
|
91 |
+
pretty=pretty, Lego=Lego,
|
92 |
+
drawHand=t.hand if drawHand else None)
|
93 |
+
import scipy.misc
|
94 |
+
scipy.misc.imsave(f, a)
|
95 |
+
|
96 |
+
def logLikelihood(self, e, timeout=None):
|
97 |
+
from dreamcoder.domains.tower.tower_common import centerTower
|
98 |
+
yh = executeTower(e, timeout)
|
99 |
+
if yh is not None and centerTower(yh) == centerTower(self.plan): return 0.
|
100 |
+
return NEGATIVEINFINITY
|
101 |
+
|
102 |
+
|
103 |
+
|
104 |
+
def parseTower(s):
|
105 |
+
_13 = Program.parse("1x3")
|
106 |
+
_31 = Program.parse("3x1")
|
107 |
+
_r = Program.parse("right")
|
108 |
+
_l = Program.parse("left")
|
109 |
+
_addition = Program.parse("+")
|
110 |
+
_subtraction = Program.parse("-")
|
111 |
+
_lp = Program.parse("tower_loopM")
|
112 |
+
_e = Program.parse("tower_embed")
|
113 |
+
|
114 |
+
from sexpdata import loads, Symbol
|
115 |
+
s = loads(s)
|
116 |
+
def command(k, environment, continuation):
|
117 |
+
if k == Symbol("1x3") or k == Symbol("v"): return Application(_13, continuation)
|
118 |
+
if k == Symbol("3x1") or k == Symbol("h"): return Application(_31, continuation)
|
119 |
+
assert isinstance(k,list)
|
120 |
+
if k[0] == Symbol("r"): return Application(Application(_r, expression(k[1],environment)),continuation)
|
121 |
+
if k[0] == Symbol("l"): return Application(Application(_l, expression(k[1],environment)),continuation)
|
122 |
+
if k[0] == Symbol("for"):
|
123 |
+
v = k[1]
|
124 |
+
b = expression(k[2], environment)
|
125 |
+
newEnvironment = [None, v] + environment
|
126 |
+
body = block(k[3:], newEnvironment, Index(0))
|
127 |
+
return Application(Application(Application(_lp,b),
|
128 |
+
Abstraction(Abstraction(body))),
|
129 |
+
continuation)
|
130 |
+
if k[0] == Symbol("embed"):
|
131 |
+
body = block(k[1:], [None] + environment, Index(0))
|
132 |
+
return Application(Application(_e,Abstraction(body)),continuation)
|
133 |
+
|
134 |
+
assert False
|
135 |
+
def expression(e, environment):
|
136 |
+
for n, v in enumerate(environment):
|
137 |
+
if e == v: return Index(n)
|
138 |
+
|
139 |
+
if isinstance(e,int): return Program.parse(str(e))
|
140 |
+
|
141 |
+
assert isinstance(e,list)
|
142 |
+
if e[0] == Symbol('+'): return Application(Application(_addition, expression(e[1], environment)),
|
143 |
+
expression(e[2], environment))
|
144 |
+
if e[0] == Symbol('-'): return Application(Application(_subtraction, expression(e[1], environment)),
|
145 |
+
expression(e[2], environment))
|
146 |
+
assert False
|
147 |
+
|
148 |
+
def block(b, environment, continuation):
|
149 |
+
if len(b) == 0: return continuation
|
150 |
+
return command(b[0], environment, block(b[1:], environment, continuation))
|
151 |
+
|
152 |
+
try: return Abstraction(command(s, [], Index(0)))
|
153 |
+
except: return Abstraction(block(s, [], Index(0)))
|
154 |
+
|
155 |
+
|
156 |
+
def makeSupervisedTasks():
|
157 |
+
arches = [SupervisedTower("arch leg %d"%n,
|
158 |
+
"((for i %d v) (r 4) (for i %d v) (l 2) h)"%(n,n))
|
159 |
+
for n in range(1,9)
|
160 |
+
]
|
161 |
+
archesStacks = [SupervisedTower("arch stack %d"%n,
|
162 |
+
"""
|
163 |
+
(for i %d
|
164 |
+
v (r 4) v (l 2) h (l 2))
|
165 |
+
"""%n)
|
166 |
+
for n in range(3,7) ]
|
167 |
+
Bridges = [SupervisedTower("bridge (%d) of arch %d"%(n,l),
|
168 |
+
"""
|
169 |
+
(for j %d
|
170 |
+
(for i %d
|
171 |
+
v (r 4) v (l 4)) (r 2) h
|
172 |
+
(r 4))
|
173 |
+
"""%(n,l))
|
174 |
+
for n in range(2,8)
|
175 |
+
for l in range(1,6)]
|
176 |
+
offsetArches = [SupervisedTower("bridge (%d) of arch, spaced %d"%(n,l),
|
177 |
+
"""
|
178 |
+
(for j %d
|
179 |
+
(embed v (r 4) v (l 2) h )
|
180 |
+
(r %d))
|
181 |
+
"""%(n,l),
|
182 |
+
mustTrain=n == 3)
|
183 |
+
for n,l in [(3,7),(4,8)]]
|
184 |
+
Josh = [SupervisedTower("Josh (%d)"%n,
|
185 |
+
"""(for i %d
|
186 |
+
h (l 2) v (r 2) v (r 2) v (l 2) h (r 6))"""%n)
|
187 |
+
for n in range(1,7) ]
|
188 |
+
|
189 |
+
staircase1 = [SupervisedTower("R staircase %d"%n,
|
190 |
+
"""
|
191 |
+
(for i %d (for j i
|
192 |
+
(embed v (r 4) v (l 2) h)) (r 6))
|
193 |
+
"""%(n))
|
194 |
+
for n in range(3,8) ]
|
195 |
+
staircase2 = [SupervisedTower("L staircase %d"%n,
|
196 |
+
"""
|
197 |
+
(for i %d (for j i
|
198 |
+
(embed v (r 4) v (l 2) h)) (l 6))
|
199 |
+
"""%(n))
|
200 |
+
for n in range(3,8) ]
|
201 |
+
simpleLoops = [SupervisedTower("%s row %d, spacing %d"%(o,n,s),
|
202 |
+
"""(for j %d %s (r %s))"""%(n,o,s),
|
203 |
+
mustTrain=True)
|
204 |
+
for o,n,s in [('h',4,7), ('v',5,3)] ]
|
205 |
+
|
206 |
+
pyramids = []
|
207 |
+
pyramids += [SupervisedTower("arch pyramid %d"%n,
|
208 |
+
"""((for i %d (for j i (embed v (r 4) v (l 2) h)) (r 6))
|
209 |
+
(for i %d (for j (- %d i) (embed v (r 4) v (l 2) h)) (r 6)))"""%(n,n,n))
|
210 |
+
for n in range(2,6) ]
|
211 |
+
pyramids += [SupervisedTower("H pyramid %d"%n,
|
212 |
+
"""((for i %d (for j i h) (r 6))
|
213 |
+
(for i %d (for j (- %d i) h) (r 6)))"""%(n,n,n))
|
214 |
+
for n in range(4,6) ]
|
215 |
+
# pyramids += [SupervisedTower("V pyramid %d"%n,
|
216 |
+
# """
|
217 |
+
# ((for i %d (for j i v) (r 2))
|
218 |
+
# (for i %d (for j (- %d i) v) (r 2)))
|
219 |
+
# """%(n,n,n))
|
220 |
+
# for n in range(4,8) ]
|
221 |
+
# pyramids += [SupervisedTower("V3 pyramid %d"%n,
|
222 |
+
# """
|
223 |
+
# ((for i %d (for j i v) (r 6))
|
224 |
+
# (for i %d (for j (- %d i) v) (r 6)))
|
225 |
+
# """%(n,n,n))
|
226 |
+
# for n in range(4,8) ]
|
227 |
+
pyramids += [SupervisedTower("H 1/2 pyramid %d"%n,
|
228 |
+
"""
|
229 |
+
(for i %d
|
230 |
+
(r 6)
|
231 |
+
(embed
|
232 |
+
(for j i h (l 3))))
|
233 |
+
"""%n)
|
234 |
+
for n in range(4,8) ]
|
235 |
+
pyramids += [SupervisedTower("arch 1/2 pyramid %d"%n,
|
236 |
+
"""
|
237 |
+
(for i %d
|
238 |
+
(r 6)
|
239 |
+
(embed
|
240 |
+
(for j i (embed v (r 4) v (l 2) h) (l 3))))
|
241 |
+
"""%n)
|
242 |
+
for n in range(2,8) ]
|
243 |
+
if False:
|
244 |
+
pyramids += [SupervisedTower("V 1/2 pyramid %d"%n,
|
245 |
+
"""
|
246 |
+
(for i %d
|
247 |
+
(r 2)
|
248 |
+
(embed
|
249 |
+
(for j i v (l 1))))"""%(n))
|
250 |
+
for n in range(4,8) ]
|
251 |
+
bricks = [SupervisedTower("brickwall, %dx%d"%(w,h),
|
252 |
+
"""(for j %d
|
253 |
+
(embed (for i %d h (r 6)))
|
254 |
+
(embed (r 3) (for i %d h (r 6))))"""%(h,w,w))
|
255 |
+
for w in range(3,7)
|
256 |
+
for h in range(1,6) ]
|
257 |
+
aqueducts = [SupervisedTower("aqueduct: %dx%d"%(w,h),
|
258 |
+
"""(for j %d
|
259 |
+
%s (r 4) %s (l 2) h (l 2) v (r 4) v (l 2) h (r 4))"""%
|
260 |
+
(w, "v "*h, "v "*h))
|
261 |
+
for w in range(4,8)
|
262 |
+
for h in range(3,6)
|
263 |
+
]
|
264 |
+
|
265 |
+
compositions = [SupervisedTower("%dx%d-bridge on top of %dx%d bricks"%(b1,b2,w1,w2),
|
266 |
+
"""
|
267 |
+
((for j %d
|
268 |
+
(embed (for i %d h (r 6)))
|
269 |
+
(embed (r 3) (for i %d h (r 6))))
|
270 |
+
(r 1)
|
271 |
+
(for j %d
|
272 |
+
(for i %d
|
273 |
+
v (r 4) v (l 4)) (r 2) h
|
274 |
+
(r 4)))
|
275 |
+
"""%(w1,w2,w2,b1,b2))
|
276 |
+
for b1,b2,w1,w2 in [(5,2,4,5)]
|
277 |
+
] + [
|
278 |
+
SupervisedTower("%d pyramid on top of %dx%d bricks"%(p,w1,w2),
|
279 |
+
"""
|
280 |
+
((for j %d
|
281 |
+
(embed (for i %d h (r 6)))
|
282 |
+
(embed (r 3) (for i %d h (r 6))))
|
283 |
+
(r 1)
|
284 |
+
(for i %d (for j i (embed v (r 4) v (l 2) h)) (r 6))
|
285 |
+
(for i %d (for j (- %d i) (embed v (r 4) v (l 2) h)) (r 6)))
|
286 |
+
"""%(w1,w2,w2,p,p,p))
|
287 |
+
for w1,w2,p in [(2,5,2)]
|
288 |
+
] + \
|
289 |
+
[
|
290 |
+
SupervisedTower("%d tower on top of %dx%d bricks"%(t,w1,w2),
|
291 |
+
"""
|
292 |
+
((for j %d
|
293 |
+
(embed (for i %d h (r 6)))
|
294 |
+
(embed (r 3) (for i %d h (r 6))))
|
295 |
+
(r 6)
|
296 |
+
%s (r 4) %s (l 2) h)
|
297 |
+
"""%(w1,w2,w2,
|
298 |
+
"v "*t, "v "*t))
|
299 |
+
for t,w1,w2 in [(4,1,3)] ]
|
300 |
+
|
301 |
+
|
302 |
+
|
303 |
+
everything = arches + simpleLoops + Bridges + archesStacks + aqueducts + offsetArches + pyramids + bricks + staircase2 + staircase1 + compositions
|
304 |
+
if False:
|
305 |
+
for t in everything:
|
306 |
+
delattr(t,'original')
|
307 |
+
return everything
|
308 |
+
|
309 |
+
def makeOldSupervisedTasks():
|
310 |
+
arches = [SupervisedTower("arch leg %d"%n,
|
311 |
+
"((for i %d v) (r 4) (for i %d v) (l 2) h)"%(n,n))
|
312 |
+
for n in range(1,9)
|
313 |
+
]
|
314 |
+
archesStacks = [SupervisedTower("arch stack %d"%n,
|
315 |
+
"""
|
316 |
+
(for i %d
|
317 |
+
v (r 4) v (l 2) h (l 2))
|
318 |
+
"""%n)
|
319 |
+
for n in range(3,7) ]
|
320 |
+
Bridges = [SupervisedTower("bridge (%d) of arch %d"%(n,l),
|
321 |
+
"""
|
322 |
+
(for j %d
|
323 |
+
(for i %d
|
324 |
+
v (r 4) v (l 4)) (r 2) h
|
325 |
+
(r 4))
|
326 |
+
"""%(n,l))
|
327 |
+
for n in range(2,8)
|
328 |
+
for l in range(1,6)]
|
329 |
+
offsetArches = [SupervisedTower("bridge (%d) of arch, spaced %d"%(n,l),
|
330 |
+
"""
|
331 |
+
(for j %d
|
332 |
+
v (r 4) v (l 2) h
|
333 |
+
(r %d))
|
334 |
+
"""%(n,l))
|
335 |
+
for n,l in [(3,7),(4,6)]]
|
336 |
+
Josh = [SupervisedTower("Josh (%d)"%n,
|
337 |
+
"""(for i %d
|
338 |
+
h (l 2) v (r 2) v (r 2) v (l 2) h (r 6))"""%n)
|
339 |
+
for n in range(1,7) ]
|
340 |
+
|
341 |
+
staircase1 = [SupervisedTower("R staircase %d"%n,
|
342 |
+
"""
|
343 |
+
(for i %d (for j i
|
344 |
+
(embed v (r 4) v (l 2) h)) (r 6))
|
345 |
+
"""%(n))
|
346 |
+
for n in range(3,8) ]
|
347 |
+
staircase2 = [SupervisedTower("L staircase %d"%n,
|
348 |
+
"""
|
349 |
+
(for i %d (for j i
|
350 |
+
(embed v (r 4) v (l 2) h)) (l 6))
|
351 |
+
"""%(n))
|
352 |
+
for n in range(3,8) ]
|
353 |
+
simpleLoops = [SupervisedTower("horizontal row %d, spacing %d"%(n,s),
|
354 |
+
"""(for j %d h (r %s))"""%(n,s))
|
355 |
+
for n,s in [(4,6),(5,7)] ]+\
|
356 |
+
[SupervisedTower("horizontal stack %d"%n,
|
357 |
+
"""(for j %d h)"""%n)
|
358 |
+
for n in range(5,8) ]+\
|
359 |
+
[SupervisedTower("vertical stack %d"%n,
|
360 |
+
"""(for j %d v)"""%n)
|
361 |
+
for n in [5,7] ]
|
362 |
+
pyramids = []
|
363 |
+
pyramids += [SupervisedTower("arch pyramid %d"%n,
|
364 |
+
"""((for i %d (for j i (embed v (r 4) v (l 2) h)) (r 6))
|
365 |
+
(for i %d (for j (- %d i) (embed v (r 4) v (l 2) h)) (r 6)))"""%(n,n,n))
|
366 |
+
for n in range(2,6) ]
|
367 |
+
pyramids += [SupervisedTower("H pyramid %d"%n,
|
368 |
+
"""((for i %d (for j i h) (r 6))
|
369 |
+
(for i %d (for j (- %d i) h) (r 6)))"""%(n,n,n))
|
370 |
+
for n in range(4,6) ]
|
371 |
+
# pyramids += [SupervisedTower("V pyramid %d"%n,
|
372 |
+
# """
|
373 |
+
# ((for i %d (for j i v) (r 2))
|
374 |
+
# (for i %d (for j (- %d i) v) (r 2)))
|
375 |
+
# """%(n,n,n))
|
376 |
+
# for n in range(4,8) ]
|
377 |
+
# pyramids += [SupervisedTower("V3 pyramid %d"%n,
|
378 |
+
# """
|
379 |
+
# ((for i %d (for j i v) (r 6))
|
380 |
+
# (for i %d (for j (- %d i) v) (r 6)))
|
381 |
+
# """%(n,n,n))
|
382 |
+
# for n in range(4,8) ]
|
383 |
+
pyramids += [SupervisedTower("H 1/2 pyramid %d"%n,
|
384 |
+
"""
|
385 |
+
(for i %d
|
386 |
+
(r 6)
|
387 |
+
(embed
|
388 |
+
(for j i h (l 3))))
|
389 |
+
"""%n)
|
390 |
+
for n in range(4,8) ]
|
391 |
+
pyramids += [SupervisedTower("arch 1/2 pyramid %d"%n,
|
392 |
+
"""
|
393 |
+
(for i %d
|
394 |
+
(r 6)
|
395 |
+
(embed
|
396 |
+
(for j i (embed v (r 4) v (l 2) h) (l 3))))
|
397 |
+
"""%n)
|
398 |
+
for n in range(2,8) ]
|
399 |
+
if False:
|
400 |
+
pyramids += [SupervisedTower("V 1/2 pyramid %d"%n,
|
401 |
+
"""
|
402 |
+
(for i %d
|
403 |
+
(r 2)
|
404 |
+
(embed
|
405 |
+
(for j i v (l 1))))"""%(n))
|
406 |
+
for n in range(4,8) ]
|
407 |
+
bricks = [SupervisedTower("brickwall, %dx%d"%(w,h),
|
408 |
+
"""(for j %d
|
409 |
+
(embed (for i %d h (r 6)))
|
410 |
+
(embed (r 3) (for i %d h (r 6))))"""%(h,w,w))
|
411 |
+
for w in range(3,7)
|
412 |
+
for h in range(1,6) ]
|
413 |
+
aqueducts = [SupervisedTower("aqueduct: %dx%d"%(w,h),
|
414 |
+
"""(for j %d
|
415 |
+
%s (r 4) %s (l 2) h (l 2) v (r 4) v (l 2) h (r 4))"""%
|
416 |
+
(w, "v "*h, "v "*h))
|
417 |
+
for w in range(4,8)
|
418 |
+
for h in range(3,6)
|
419 |
+
]
|
420 |
+
|
421 |
+
compositions = [SupervisedTower("%dx%d-bridge on top of %dx%d bricks"%(b1,b2,w1,w2),
|
422 |
+
"""
|
423 |
+
((for j %d
|
424 |
+
(embed (for i %d h (r 6)))
|
425 |
+
(embed (r 3) (for i %d h (r 6))))
|
426 |
+
(r 1)
|
427 |
+
(for j %d
|
428 |
+
(for i %d
|
429 |
+
v (r 4) v (l 4)) (r 2) h
|
430 |
+
(r 4)))
|
431 |
+
"""%(w1,w2,w2,b1,b2))
|
432 |
+
for b1,b2,w1,w2 in [(5,2,4,5)]
|
433 |
+
] + [
|
434 |
+
SupervisedTower("%d pyramid on top of %dx%d bricks"%(p,w1,w2),
|
435 |
+
"""
|
436 |
+
((for j %d
|
437 |
+
(embed (for i %d h (r 6)))
|
438 |
+
(embed (r 3) (for i %d h (r 6))))
|
439 |
+
(r 1)
|
440 |
+
(for i %d (for j i (embed v (r 4) v (l 2) h)) (r 6))
|
441 |
+
(for i %d (for j (- %d i) (embed v (r 4) v (l 2) h)) (r 6)))
|
442 |
+
"""%(w1,w2,w2,p,p,p))
|
443 |
+
for w1,w2,p in [(2,5,2)]
|
444 |
+
] + \
|
445 |
+
[
|
446 |
+
SupervisedTower("%d tower on top of %dx%d bricks"%(t,w1,w2),
|
447 |
+
"""
|
448 |
+
((for j %d
|
449 |
+
(embed (for i %d h (r 6)))
|
450 |
+
(embed (r 3) (for i %d h (r 6))))
|
451 |
+
(r 6)
|
452 |
+
%s (r 4) %s (l 2) h)
|
453 |
+
"""%(w1,w2,w2,
|
454 |
+
"v "*t, "v "*t))
|
455 |
+
for t,w1,w2 in [(4,1,3)] ]
|
456 |
+
|
457 |
+
|
458 |
+
|
459 |
+
everything = arches + simpleLoops + Bridges + archesStacks + aqueducts + offsetArches + pyramids + bricks + staircase2 + staircase1 + compositions
|
460 |
+
if False:
|
461 |
+
for t in everything:
|
462 |
+
delattr(t,'original')
|
463 |
+
return everything
|
464 |
+
|
465 |
+
def dSLDemo():
|
466 |
+
DSL = {}
|
467 |
+
bricks = Program.parse("(lambda (lambda (tower_loopM $0 (lambda (lambda (moveHand 3 (reverseHand (tower_loopM $3 (lambda (lambda (moveHand 6 (3x1 $0)))) $0))))))))")
|
468 |
+
DSL["bricks"] = [ [bricks.runWithArguments([x,y + 4,_empty_tower,TowerState()])[1]
|
469 |
+
for y in range(6, 6 + 3*4, 3) ]
|
470 |
+
for x in [3,8] ]
|
471 |
+
dimensionality = {}
|
472 |
+
dimensionality["bricks"] = 2
|
473 |
+
|
474 |
+
bridge = Program.parse("(lambda (lambda (tower_loopM $0 (lambda (lambda (#(lambda (#(lambda (lambda (lambda (tower_loopM $0 (lambda (lambda (1x3 (moveHand 4 ($3 $0))))) (moveHand 2 (3x1 $2)))))) $0 (lambda (reverseHand $0)))) (moveHand 4 $0) $3))))))")
|
475 |
+
DSL["bridge"] = [ [bridge.runWithArguments([x,y,_empty_tower,TowerState()])[1]
|
476 |
+
for x in range(4,4 + 2*4,2) ]
|
477 |
+
for y in [4,9] ]
|
478 |
+
dimensionality["bridge"] = 2
|
479 |
+
|
480 |
+
staircase = Program.parse("(lambda (tower_loopM $0 (lambda (lambda (#(lambda (lambda (tower_loopM $1 (lambda (lambda (tower_embed (lambda (#(lambda (1x3 (moveHand 4 (1x3 (reverseHand (moveHand 2 (3x1 $0))))))) $0)) $0))) $0))) $1 (moveHand 6 $0))))))")
|
481 |
+
DSL["staircase"] = [ staircase.runWithArguments([n,_empty_tower,TowerState()])[1]
|
482 |
+
for n in range(4,5 + 3) ]
|
483 |
+
|
484 |
+
pyramid = Program.parse("(lambda (tower_loopM $0 (lambda (lambda (moveHand 6 (tower_embed (lambda (reverseHand ((lambda (lambda (tower_loopM $1 (lambda (lambda (moveHand $2 (1x3 (moveHand 2 (tower_embed (lambda (moveHand 2 (1x3 $0))) (3x1 $0)))))))))) $2 1 $0))) $0))))))")
|
485 |
+
DSL["pyramid"] = [ pyramid.runWithArguments([n,_empty_tower,TowerState()])[1]
|
486 |
+
for n in range(4,5 + 3) ]
|
487 |
+
|
488 |
+
towerArch = Program.parse("(lambda (lambda ((lambda ((lambda (lambda (lambda (tower_loopM $0 (lambda (lambda (1x3 (moveHand 4 ($3 $0))))) (moveHand 2 (3x1 $2)))))) $0 (lambda (reverseHand (1x3 $0))))) $0 $1)))")
|
489 |
+
DSL["towerArch"] = [ towerArch.runWithArguments([n,_empty_tower,TowerState()])[1]
|
490 |
+
for n in range(4,5 + 3) ]
|
491 |
+
|
492 |
+
images = {}
|
493 |
+
for k,v in DSL.items():
|
494 |
+
d = dimensionality.get(k,1)
|
495 |
+
if d == 1:
|
496 |
+
i = montageMatrix([[renderPlan(p, pretty=True, Lego=True) for p in v]])
|
497 |
+
elif d == 2:
|
498 |
+
i = montageMatrix([[renderPlan(p, pretty=True, Lego=True) for p in ps] for ps in v] )
|
499 |
+
else: assert False
|
500 |
+
|
501 |
+
images[k] = i
|
502 |
+
|
503 |
+
return images
|
504 |
+
|
505 |
+
if __name__ == "__main__":
|
506 |
+
from pylab import imshow,show
|
507 |
+
from dreamcoder.domains.tower.tower_common import *
|
508 |
+
|
509 |
+
ts = makeSupervisedTasks()
|
510 |
+
print(len(ts),"total tasks")
|
511 |
+
print("maximum plan length",max(len(f.plan) for f in ts ))
|
512 |
+
print("maximum tower length",max(towerLength(f.plan) for f in ts ))
|
513 |
+
print("maximum tower height",max(towerHeight(simulateWithoutPhysics(f.plan)) for f in ts ))
|
514 |
+
SupervisedTower.exportMany("/tmp/every_tower.png",ts,shuffle=False)
|
515 |
+
|
516 |
+
for j,t in enumerate(ts):
|
517 |
+
t.exportImage("/tmp/tower_%d.png"%j,
|
518 |
+
drawHand=False)
|
519 |
+
|
520 |
+
for k,v in dSLDemo().items():
|
521 |
+
import scipy.misc
|
522 |
+
scipy.misc.imsave(f"/tmp/tower_dsl_{k}.png", v)
|
523 |
+
|
524 |
+
exampleTowers = [103,104,105,93,73,
|
525 |
+
50,67,35,43,106]
|
526 |
+
SupervisedTower.exportMany("/tmp/tower_montage.png",
|
527 |
+
[ts[n] for n in exampleTowers ],
|
528 |
+
columns=5,
|
529 |
+
shuffle=False)
|
530 |
+
assert False
|
531 |
+
|
532 |
+
|
533 |
+
keywords = ["pyramid",
|
534 |
+
"on top of",
|
535 |
+
"arch 1/2 pyramid",
|
536 |
+
"brickwall",
|
537 |
+
"staircase",
|
538 |
+
"bridge",
|
539 |
+
"aqueduct",
|
540 |
+
"spaced",
|
541 |
+
"spaced",
|
542 |
+
"arch stack"]
|
543 |
+
for n in range(100):
|
544 |
+
examples = []
|
545 |
+
for kw in keywords:
|
546 |
+
if kw == "on top of":
|
547 |
+
examples = examples + list(filter(lambda t: kw in str(t), ts))
|
548 |
+
else:
|
549 |
+
examples.append(random.choice(list(filter(lambda t: kw in str(t), ts))))
|
550 |
+
|
551 |
+
random.shuffle(examples)
|
552 |
+
SupervisedTower.exportMany("/tmp/tower10_%d.png"%n,examples,
|
553 |
+
columns=int(len(examples)/2))
|
554 |
+
|
555 |
+
|
556 |
+
|
dreamcoder/domains/tower/towerPrimitives.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dreamcoder.program import *
|
2 |
+
|
3 |
+
|
4 |
+
class TowerState:
|
5 |
+
def __init__(self, hand=0, orientation=1, history=None):
|
6 |
+
# List of (State|Block)
|
7 |
+
self.history = history
|
8 |
+
self.hand = hand
|
9 |
+
self.orientation = orientation
|
10 |
+
def __str__(self): return f"S(h={self.hand},o={self.orientation})"
|
11 |
+
def __repr__(self): return str(self)
|
12 |
+
def left(self, n):
|
13 |
+
return TowerState(hand=self.hand - n, orientation=self.orientation,
|
14 |
+
history=self.history if self.history is None \
|
15 |
+
else self.history + [self])
|
16 |
+
def right(self, n): return TowerState(hand=self.hand + n, orientation=self.orientation,
|
17 |
+
history=self.history if self.history is None \
|
18 |
+
else self.history + [self])
|
19 |
+
def reverse(self): return TowerState(hand=self.hand, orientation=-1*self.orientation,
|
20 |
+
history=self.history if self.history is None \
|
21 |
+
else self.history + [self])
|
22 |
+
def move(self, n): return TowerState(hand=self.hand + n*self.orientation, orientation=self.orientation,
|
23 |
+
history=self.history if self.history is None \
|
24 |
+
else self.history + [self])
|
25 |
+
|
26 |
+
def recordBlock(self, b):
|
27 |
+
if self.history is None: return self
|
28 |
+
return TowerState(hand=self.hand,
|
29 |
+
orientation=self.orientation,
|
30 |
+
history=self.history + [b])
|
31 |
+
|
32 |
+
|
33 |
+
def _empty_tower(h): return (h,[])
|
34 |
+
def _left(d):
|
35 |
+
return lambda k: lambda s: k(s.left(d))
|
36 |
+
def _right(d):
|
37 |
+
return lambda k: lambda s: k(s.right(d))
|
38 |
+
def _loop(n):
|
39 |
+
def f(start, stop, body, state):
|
40 |
+
if start >= stop: return state,[]
|
41 |
+
state, thisIteration = body(start)(state)
|
42 |
+
state, laterIterations = f(start + 1, stop, body, state)
|
43 |
+
return state, thisIteration + laterIterations
|
44 |
+
def sequence(b,k,h):
|
45 |
+
h,bodyBlocks = f(0,n,b,h)
|
46 |
+
h,laterBlocks = k(h)
|
47 |
+
return h,bodyBlocks+laterBlocks
|
48 |
+
return lambda b: lambda k: lambda h: sequence(b,k,h)
|
49 |
+
def _simpleLoop(n):
|
50 |
+
def f(start, body, k):
|
51 |
+
if start >= n: return k
|
52 |
+
return body(start)(f(start + 1, body, k))
|
53 |
+
return lambda b: lambda k: f(0,b,k)
|
54 |
+
def _embed(body):
|
55 |
+
def f(k):
|
56 |
+
def g(hand):
|
57 |
+
bodyHand, bodyActions = body(_empty_tower)(hand)
|
58 |
+
# Record history if we are doing that
|
59 |
+
if hand.history is not None:
|
60 |
+
hand = TowerState(hand=hand.hand,
|
61 |
+
orientation=hand.orientation,
|
62 |
+
history=bodyHand.history)
|
63 |
+
hand, laterActions = k(hand)
|
64 |
+
return hand, bodyActions + laterActions
|
65 |
+
return g
|
66 |
+
return f
|
67 |
+
def _moveHand(n):
|
68 |
+
return lambda k: lambda s: k(s.move(n))
|
69 |
+
def _reverseHand(k):
|
70 |
+
return lambda s: k(s.reverse())
|
71 |
+
|
72 |
+
class TowerContinuation(object):
|
73 |
+
def __init__(self, x, w, h):
|
74 |
+
self.x = x
|
75 |
+
self.w = w*2
|
76 |
+
self.h = h*2
|
77 |
+
def __call__(self, k):
|
78 |
+
def f(hand):
|
79 |
+
thisAction = [(self.x + hand.hand,self.w,self.h)]
|
80 |
+
hand = hand.recordBlock(thisAction[0])
|
81 |
+
hand, rest = k(hand)
|
82 |
+
return hand, thisAction + rest
|
83 |
+
return f
|
84 |
+
|
85 |
+
# name, dimensions
|
86 |
+
blocks = {
|
87 |
+
# "1x1": (1.,1.),
|
88 |
+
# "2x1": (2.,1.),
|
89 |
+
# "1x2": (1.,2.),
|
90 |
+
"3x1": (3, 1),
|
91 |
+
"1x3": (1, 3),
|
92 |
+
# "4x1": (4.,1.),
|
93 |
+
# "1x4": (1.,4.)
|
94 |
+
}
|
95 |
+
|
96 |
+
|
97 |
+
ttower = baseType("tower")
|
98 |
+
common_primitives = [
|
99 |
+
Primitive("tower_loopM", arrow(tint, arrow(tint, ttower, ttower), ttower, ttower), _simpleLoop),
|
100 |
+
Primitive("tower_embed", arrow(arrow(ttower,ttower), ttower, ttower), _embed),
|
101 |
+
] + [Primitive(name, arrow(ttower,ttower), TowerContinuation(0, w, h))
|
102 |
+
for name, (w, h) in blocks.items()] + \
|
103 |
+
[Primitive(str(j), tint, j) for j in range(1,9) ]
|
104 |
+
primitives = common_primitives + [
|
105 |
+
Primitive("left", arrow(tint, ttower, ttower), _left),
|
106 |
+
Primitive("right", arrow(tint, ttower, ttower), _right)
|
107 |
+
]
|
108 |
+
|
109 |
+
new_primitives = common_primitives + [
|
110 |
+
Primitive("moveHand", arrow(tint, ttower, ttower), _moveHand),
|
111 |
+
Primitive("reverseHand", arrow(ttower, ttower), _reverseHand)
|
112 |
+
]
|
113 |
+
|
114 |
+
def executeTower(p, timeout=None):
|
115 |
+
try:
|
116 |
+
return runWithTimeout(lambda : p.evaluate([])(_empty_tower)(TowerState())[1],
|
117 |
+
timeout=timeout)
|
118 |
+
except RunWithTimeout: return None
|
119 |
+
except: return None
|
120 |
+
|
121 |
+
def animateTower(exportPrefix, p):
|
122 |
+
print(exportPrefix, p)
|
123 |
+
from dreamcoder.domains.tower.tower_common import renderPlan
|
124 |
+
state,actions = p.evaluate([])(_empty_tower)(TowerState(history=[]))
|
125 |
+
print(actions)
|
126 |
+
trajectory = state.history + [state]
|
127 |
+
print(trajectory)
|
128 |
+
print()
|
129 |
+
|
130 |
+
assert tuple(z for z in trajectory if not isinstance(z, TowerState) ) == tuple(actions)
|
131 |
+
|
132 |
+
def hd(n):
|
133 |
+
h = 0
|
134 |
+
for state in trajectory[:n]:
|
135 |
+
if isinstance(state, TowerState):
|
136 |
+
h = state.hand
|
137 |
+
return h
|
138 |
+
animation = [renderPlan([b for b in trajectory[:n] if not isinstance(b, TowerState)],
|
139 |
+
pretty=True, Lego=True,
|
140 |
+
drawHand=hd(n),
|
141 |
+
masterPlan=actions,
|
142 |
+
randomSeed=hash(exportPrefix))
|
143 |
+
for n in range(0,len(trajectory) + 1)]
|
144 |
+
import scipy.misc
|
145 |
+
import random
|
146 |
+
r = random.random()
|
147 |
+
paths = []
|
148 |
+
for n in range(len(animation)):
|
149 |
+
paths.append(f"{exportPrefix}_{n}.png")
|
150 |
+
scipy.misc.imsave(paths[-1], animation[n])
|
151 |
+
os.system(f"convert -delay 10 -loop 0 {' '.join(paths)} {exportPrefix}.gif")
|
152 |
+
# os.system(f"rm {' '.join(paths)}")
|
dreamcoder/domains/tower/tower_common.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import math
|
3 |
+
from dreamcoder.utilities import *
|
4 |
+
|
5 |
+
def simulateWithoutPhysics(plan,ordered=True):
|
6 |
+
def overlap(b1,
|
7 |
+
b2):
|
8 |
+
(x,w,h) = b1
|
9 |
+
(x_,y_,w_,h_) = b2
|
10 |
+
x1 = x - w/2
|
11 |
+
x2 = x + w/2
|
12 |
+
x1_ = x_ - w_/2
|
13 |
+
x2_ = x_ + w_/2
|
14 |
+
if x1_ >= x2 or x1 >= x2_: return None
|
15 |
+
assert h%2 == 0 and h_%2 == 0
|
16 |
+
return y_ + h_//2 + h//2
|
17 |
+
def lowestPossibleHeight(b):
|
18 |
+
h = b[2]
|
19 |
+
assert h%2 == 0
|
20 |
+
return int(h/2)
|
21 |
+
def placeAtHeight(b,y):
|
22 |
+
(x,w,h) = b
|
23 |
+
return (x,y,w,h)
|
24 |
+
def placeBlock(world, block):
|
25 |
+
lowest = max([lowestPossibleHeight(block)] + \
|
26 |
+
[overlap(block,other)
|
27 |
+
for other in world
|
28 |
+
if overlap(block,other) is not None])
|
29 |
+
world.append(placeAtHeight(block, lowest))
|
30 |
+
|
31 |
+
w = []
|
32 |
+
for p in plan: placeBlock(w,p)
|
33 |
+
if ordered: w = list(sorted(w))
|
34 |
+
return w
|
35 |
+
|
36 |
+
def centerTower(t,hand=None, masterPlan=None):
|
37 |
+
|
38 |
+
if len(t) == 0:
|
39 |
+
if hand is None:
|
40 |
+
return t
|
41 |
+
else:
|
42 |
+
return t, hand
|
43 |
+
def getCenter(t):
|
44 |
+
x1 = max(x for x, _, _ in t)
|
45 |
+
x0 = min(x for x, _, _ in t)
|
46 |
+
c = int((x1 - x0) / 2.0) + x0
|
47 |
+
return c
|
48 |
+
c = getCenter(masterPlan or t)
|
49 |
+
t = [(x - c, w, h) for x, w, h in t]
|
50 |
+
if hand is None:
|
51 |
+
return t
|
52 |
+
else:
|
53 |
+
return t, hand - c
|
54 |
+
|
55 |
+
def towerLength(t):
|
56 |
+
if len(t) == 0: return 0
|
57 |
+
x1 = max(x for x, _, _ in t)
|
58 |
+
x0 = min(x for x, _, _ in t)
|
59 |
+
return x1 - x0
|
60 |
+
|
61 |
+
def towerHeight(t):
|
62 |
+
y1 = max(y + h/2 for _, y, _, h in t )
|
63 |
+
y0 = min(y - h/2 for _, y, _, h in t )
|
64 |
+
return y1 - y0
|
65 |
+
|
66 |
+
|
67 |
+
|
68 |
+
def renderPlan(plan, resolution=256, window=64, floorHeight=2, borderSize=1, bodyColor=(0.,1.,1.),
|
69 |
+
borderColor=(1.,0.,0.),
|
70 |
+
truncate=None, randomSeed=None,
|
71 |
+
masterPlan=None,
|
72 |
+
pretty=False, Lego=False,
|
73 |
+
drawHand=None):
|
74 |
+
import numpy as np
|
75 |
+
|
76 |
+
if Lego: assert pretty
|
77 |
+
|
78 |
+
if drawHand is not None and drawHand is not False:
|
79 |
+
plan, drawHand = centerTower(plan, drawHand,
|
80 |
+
masterPlan=masterPlan)
|
81 |
+
else:
|
82 |
+
plan = centerTower(plan,masterPlan=masterPlan)
|
83 |
+
|
84 |
+
world = simulateWithoutPhysics(plan,
|
85 |
+
ordered=randomSeed is None)
|
86 |
+
if truncate is not None: world = world[:truncate]
|
87 |
+
a = np.zeros((resolution, resolution, 3))
|
88 |
+
|
89 |
+
def transform(x,y):
|
90 |
+
y = resolution - y*resolution/float(window)
|
91 |
+
x = resolution/2 + x*resolution/float(window)
|
92 |
+
return int(x + 0.5),int(y + 0.5)
|
93 |
+
def clip(p):
|
94 |
+
if p < 0: return 0
|
95 |
+
if p >= resolution: return resolution - 1
|
96 |
+
return int(p + 0.5)
|
97 |
+
def clear(x,y):
|
98 |
+
for xp,yp,wp,hp in world:
|
99 |
+
if x < xp + wp/2. and \
|
100 |
+
x > xp - wp/2. and \
|
101 |
+
y < yp + hp/2. and \
|
102 |
+
y > yp - hp/2.:
|
103 |
+
return False
|
104 |
+
return True
|
105 |
+
def bump(x,y,c):
|
106 |
+
size = 0.5*resolution/window
|
107 |
+
x,y = transform(x,y)
|
108 |
+
y -= floorHeight
|
109 |
+
y1 = y
|
110 |
+
y2 = y - size
|
111 |
+
x1 = x - size/2
|
112 |
+
x2 = x + size/2
|
113 |
+
a[clip(y2) : clip(y1),
|
114 |
+
clip(x1) : clip(x2),
|
115 |
+
:] = c
|
116 |
+
|
117 |
+
|
118 |
+
if randomSeed is not None:
|
119 |
+
randomNumbers = random.Random(randomSeed)
|
120 |
+
def _color():
|
121 |
+
if randomSeed is None:
|
122 |
+
return random.random()*0.7 + 0.3
|
123 |
+
else:
|
124 |
+
return randomNumbers.random()*0.7 + 0.3
|
125 |
+
def color():
|
126 |
+
return (_color(),_color(),_color())
|
127 |
+
|
128 |
+
def rectangle(x1,x2,y1,y2,c,cp=None):
|
129 |
+
x1,y1 = transform(x1,y1)
|
130 |
+
x2,y2 = transform(x2,y2)
|
131 |
+
y1 -= floorHeight
|
132 |
+
y2 -= floorHeight
|
133 |
+
a[clip(y2) : clip(y1),
|
134 |
+
clip(x1) : clip(x2),
|
135 |
+
:] = c
|
136 |
+
if cp is not None:
|
137 |
+
a[clip(y2 + borderSize) : clip(y1 - borderSize),
|
138 |
+
clip(x1 + borderSize) : clip(x2 - borderSize),
|
139 |
+
:] = cp
|
140 |
+
|
141 |
+
for x,y,w,h in world:
|
142 |
+
x1,y1 = x - w/2., y - h/2.
|
143 |
+
x2,y2 = x + w/2., y + h/2.
|
144 |
+
if pretty:
|
145 |
+
thisColor = color()
|
146 |
+
rectangle(x1,x2,y1,y2,
|
147 |
+
thisColor)
|
148 |
+
if Lego:
|
149 |
+
bumps = w
|
150 |
+
for nb in range(bumps):
|
151 |
+
nx = x - w/2. + 0.5 + nb
|
152 |
+
ny = y + h/2. + 0.00001
|
153 |
+
if clear(nx,ny):
|
154 |
+
bump(nx,ny,thisColor)
|
155 |
+
else:
|
156 |
+
rectangle(x1,x2,y1,y2,
|
157 |
+
borderColor, bodyColor)
|
158 |
+
|
159 |
+
a[resolution - floorHeight:,:,:] = 1.
|
160 |
+
if drawHand is not None:
|
161 |
+
if not Lego:
|
162 |
+
dh = 0.25
|
163 |
+
rectangle(drawHand - dh,
|
164 |
+
drawHand + dh,
|
165 |
+
-99999, 99999,
|
166 |
+
(0,1,0))
|
167 |
+
else:
|
168 |
+
rectangle(drawHand - 1,drawHand + 1,
|
169 |
+
43,45,(1,1,1))
|
170 |
+
|
171 |
+
return a
|
172 |
+
|
173 |
+
|
dreamcoder/dreamcoder.py
ADDED
@@ -0,0 +1,1074 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
|
3 |
+
import dill
|
4 |
+
|
5 |
+
from dreamcoder.compression import induceGrammar
|
6 |
+
from dreamcoder.utilities import *
|
7 |
+
try:
|
8 |
+
from dreamcoder.recognition import *
|
9 |
+
except:
|
10 |
+
eprint("Failure loading recognition - only acceptable if using pypy ")
|
11 |
+
from dreamcoder.enumeration import *
|
12 |
+
from dreamcoder.fragmentGrammar import *
|
13 |
+
from dreamcoder.taskBatcher import *
|
14 |
+
from dreamcoder.primitiveGraph import graphPrimitives
|
15 |
+
from dreamcoder.dreaming import backgroundHelmholtzEnumeration
|
16 |
+
|
17 |
+
|
18 |
+
class ECResult():
|
19 |
+
def __init__(self, _=None,
|
20 |
+
frontiersOverTime=None,
|
21 |
+
testingSearchTime=None,
|
22 |
+
learningCurve=None,
|
23 |
+
grammars=None,
|
24 |
+
taskSolutions=None,
|
25 |
+
averageDescriptionLength=None,
|
26 |
+
parameters=None,
|
27 |
+
recognitionModel=None,
|
28 |
+
searchTimes=None,
|
29 |
+
recognitionTaskMetrics=None,
|
30 |
+
numTestingTasks=None,
|
31 |
+
sumMaxll=None,
|
32 |
+
testingSumMaxll=None,
|
33 |
+
hitsAtEachWake=None,
|
34 |
+
timesAtEachWake=None,
|
35 |
+
allFrontiers=None):
|
36 |
+
self.frontiersOverTime = {} # Map from task to [frontier at iteration 1, frontier at iteration 2, ...]
|
37 |
+
self.hitsAtEachWake = hitsAtEachWake or []
|
38 |
+
self.timesAtEachWake = timesAtEachWake or []
|
39 |
+
self.testingSearchTime = testingSearchTime or []
|
40 |
+
self.searchTimes = searchTimes or []
|
41 |
+
self.trainSearchTime = {} # map from task to search time
|
42 |
+
self.testSearchTime = {} # map from task to search time
|
43 |
+
self.recognitionTaskMetrics = recognitionTaskMetrics or {}
|
44 |
+
self.recognitionModel = recognitionModel
|
45 |
+
self.averageDescriptionLength = averageDescriptionLength or []
|
46 |
+
self.parameters = parameters
|
47 |
+
self.learningCurve = learningCurve or []
|
48 |
+
self.grammars = grammars or []
|
49 |
+
self.taskSolutions = taskSolutions or {}
|
50 |
+
self.numTestingTasks = numTestingTasks
|
51 |
+
self.sumMaxll = sumMaxll or [] #TODO name change
|
52 |
+
self.testingSumMaxll = testingSumMaxll or [] #TODO name change
|
53 |
+
self.allFrontiers = allFrontiers or {}
|
54 |
+
|
55 |
+
def __repr__(self):
|
56 |
+
attrs = ["{}={}".format(k, v) for k, v in self.__dict__.items()]
|
57 |
+
return "ECResult({})".format(", ".join(attrs))
|
58 |
+
|
59 |
+
def getTestingTasks(self):
|
60 |
+
testing = []
|
61 |
+
training = self.taskSolutions.keys()
|
62 |
+
for t in self.recognitionTaskMetrics:
|
63 |
+
if isinstance(t, Task) and t not in training: testing.append(t)
|
64 |
+
return testing
|
65 |
+
|
66 |
+
|
67 |
+
def recordFrontier(self, frontier):
|
68 |
+
t = frontier.task
|
69 |
+
if t not in self.frontiersOverTime: self.frontiersOverTime[t] = []
|
70 |
+
self.frontiersOverTime[t].append(frontier)
|
71 |
+
|
72 |
+
# Linux does not like files that have more than 256 characters
|
73 |
+
# So when exporting the results we abbreviate the parameters
|
74 |
+
abbreviations = {"frontierSize": "fs",
|
75 |
+
"useDSL": "DSL",
|
76 |
+
"taskReranker": "TRR",
|
77 |
+
"matrixRank": "MR",
|
78 |
+
"reuseRecognition": "RR",
|
79 |
+
"ensembleSize": "ES",
|
80 |
+
"recognitionTimeout": "RT",
|
81 |
+
"recognitionSteps": "RS",
|
82 |
+
"iterations": "it",
|
83 |
+
"maximumFrontier": "MF",
|
84 |
+
"pseudoCounts": "pc",
|
85 |
+
"auxiliaryLoss": "aux",
|
86 |
+
"structurePenalty": "L",
|
87 |
+
"helmholtzRatio": "HR",
|
88 |
+
"biasOptimal": "BO",
|
89 |
+
"contextual": "CO",
|
90 |
+
"topK": "K",
|
91 |
+
"enumerationTimeout": "ET",
|
92 |
+
"useRecognitionModel": "rec",
|
93 |
+
"use_ll_cutoff": "llcut",
|
94 |
+
"topk_use_only_likelihood": "topkNotMAP",
|
95 |
+
"activation": "act",
|
96 |
+
"storeTaskMetrics": 'STM',
|
97 |
+
"topkNotMAP": "tknm",
|
98 |
+
"rewriteTaskMetrics": "RW",
|
99 |
+
'taskBatchSize': 'batch'}
|
100 |
+
|
101 |
+
@staticmethod
|
102 |
+
def abbreviate(parameter): return ECResult.abbreviations.get(parameter, parameter)
|
103 |
+
|
104 |
+
@staticmethod
|
105 |
+
def parameterOfAbbreviation(abbreviation):
|
106 |
+
return ECResult.abbreviationToParameter.get(abbreviation, abbreviation)
|
107 |
+
|
108 |
+
@staticmethod
|
109 |
+
def clearRecognitionModel(path):
|
110 |
+
SUFFIX = '.pickle'
|
111 |
+
assert path.endswith(SUFFIX)
|
112 |
+
|
113 |
+
with open(path,'rb') as handle:
|
114 |
+
result = dill.load(handle)
|
115 |
+
|
116 |
+
result.recognitionModel = None
|
117 |
+
|
118 |
+
clearedPath = path[:-len(SUFFIX)] + "_graph=True" + SUFFIX
|
119 |
+
with open(clearedPath,'wb') as handle:
|
120 |
+
result = dill.dump(result, handle)
|
121 |
+
eprint(" [+] Cleared recognition model from:")
|
122 |
+
eprint(" %s"%path)
|
123 |
+
eprint(" and exported to:")
|
124 |
+
eprint(" %s"%clearedPath)
|
125 |
+
eprint(" Use this one for graphing.")
|
126 |
+
|
127 |
+
|
128 |
+
ECResult.abbreviationToParameter = {
|
129 |
+
v: k for k, v in ECResult.abbreviations.items()}
|
130 |
+
|
131 |
+
|
132 |
+
def explorationCompression(*arguments, **keywords):
|
133 |
+
for r in ecIterator(*arguments, **keywords):
|
134 |
+
pass
|
135 |
+
return r
|
136 |
+
|
137 |
+
|
138 |
+
def ecIterator(grammar, tasks,
|
139 |
+
_=None,
|
140 |
+
useDSL=True,
|
141 |
+
noConsolidation=False,
|
142 |
+
mask=False,
|
143 |
+
seed=0,
|
144 |
+
addFullTaskMetrics=False,
|
145 |
+
matrixRank=None,
|
146 |
+
solver='ocaml',
|
147 |
+
compressor="rust",
|
148 |
+
biasOptimal=False,
|
149 |
+
contextual=False,
|
150 |
+
testingTasks=[],
|
151 |
+
iterations=None,
|
152 |
+
resume=None,
|
153 |
+
enumerationTimeout=None,
|
154 |
+
testingTimeout=None,
|
155 |
+
testEvery=1,
|
156 |
+
reuseRecognition=False,
|
157 |
+
ensembleSize=1,
|
158 |
+
useRecognitionModel=True,
|
159 |
+
recognitionTimeout=None,
|
160 |
+
recognitionSteps=None,
|
161 |
+
helmholtzRatio=0.,
|
162 |
+
featureExtractor=None,
|
163 |
+
activation='relu',
|
164 |
+
topK=1,
|
165 |
+
topk_use_only_likelihood=False,
|
166 |
+
use_map_search_times=True,
|
167 |
+
maximumFrontier=None,
|
168 |
+
pseudoCounts=1.0, aic=1.0,
|
169 |
+
structurePenalty=0.001, arity=0,
|
170 |
+
evaluationTimeout=1.0, # seconds
|
171 |
+
taskBatchSize=None,
|
172 |
+
taskReranker='default',
|
173 |
+
CPUs=1,
|
174 |
+
cuda=False,
|
175 |
+
message="",
|
176 |
+
outputPrefix=None,
|
177 |
+
storeTaskMetrics=False,
|
178 |
+
rewriteTaskMetrics=True,
|
179 |
+
auxiliaryLoss=False,
|
180 |
+
custom_wake_generative=None):
|
181 |
+
if enumerationTimeout is None:
|
182 |
+
eprint(
|
183 |
+
"Please specify an enumeration timeout:",
|
184 |
+
"explorationCompression(..., enumerationTimeout = ..., ...)")
|
185 |
+
assert False
|
186 |
+
if iterations is None:
|
187 |
+
eprint(
|
188 |
+
"Please specify a iteration count: explorationCompression(..., iterations = ...)")
|
189 |
+
assert False
|
190 |
+
if useRecognitionModel and featureExtractor is None:
|
191 |
+
eprint("Warning: Recognition model needs feature extractor.",
|
192 |
+
"Ignoring recognition model.")
|
193 |
+
useRecognitionModel = False
|
194 |
+
if ensembleSize > 1 and not useRecognitionModel:
|
195 |
+
eprint("Warning: ensemble size requires using the recognition model, aborting.")
|
196 |
+
assert False
|
197 |
+
if biasOptimal and not useRecognitionModel:
|
198 |
+
eprint("Bias optimality only applies to recognition models, aborting.")
|
199 |
+
assert False
|
200 |
+
if contextual and not useRecognitionModel:
|
201 |
+
eprint("Contextual only applies to recognition models, aborting")
|
202 |
+
assert False
|
203 |
+
if reuseRecognition and not useRecognitionModel:
|
204 |
+
eprint("Reuse of recognition model weights at successive iteration only applies to recognition models, aborting")
|
205 |
+
assert False
|
206 |
+
if matrixRank is not None and not contextual:
|
207 |
+
eprint("Matrix rank only applies to contextual recognition models, aborting")
|
208 |
+
assert False
|
209 |
+
assert useDSL or useRecognitionModel, "You specified that you didn't want to use the DSL AND you don't want to use the recognition model. Figure out what you want to use."
|
210 |
+
if testingTimeout > 0 and len(testingTasks) == 0:
|
211 |
+
eprint("You specified a testingTimeout, but did not provide any held out testing tasks, aborting.")
|
212 |
+
assert False
|
213 |
+
|
214 |
+
# We save the parameters that were passed into EC
|
215 |
+
# This is for the purpose of exporting the results of the experiment
|
216 |
+
parameters = {
|
217 |
+
k: v for k,
|
218 |
+
v in locals().items() if k not in {
|
219 |
+
"tasks",
|
220 |
+
"use_map_search_times",
|
221 |
+
"seed",
|
222 |
+
"activation",
|
223 |
+
"grammar",
|
224 |
+
"cuda",
|
225 |
+
"_",
|
226 |
+
"testingTimeout",
|
227 |
+
"testEvery",
|
228 |
+
"message",
|
229 |
+
"CPUs",
|
230 |
+
"outputPrefix",
|
231 |
+
"resume",
|
232 |
+
"resumeFrontierSize",
|
233 |
+
"addFullTaskMetrics",
|
234 |
+
"featureExtractor",
|
235 |
+
"evaluationTimeout",
|
236 |
+
"testingTasks",
|
237 |
+
"compressor",
|
238 |
+
"custom_wake_generative"} and v is not None}
|
239 |
+
if not useRecognitionModel:
|
240 |
+
for k in {"helmholtzRatio", "recognitionTimeout", "biasOptimal", "mask",
|
241 |
+
"contextual", "matrixRank", "reuseRecognition", "auxiliaryLoss", "ensembleSize"}:
|
242 |
+
if k in parameters: del parameters[k]
|
243 |
+
else: del parameters["useRecognitionModel"];
|
244 |
+
if useRecognitionModel and not contextual:
|
245 |
+
if "matrixRank" in parameters:
|
246 |
+
del parameters["matrixRank"]
|
247 |
+
if "mask" in parameters:
|
248 |
+
del parameters["mask"]
|
249 |
+
if not mask and 'mask' in parameters: del parameters["mask"]
|
250 |
+
if not auxiliaryLoss and 'auxiliaryLoss' in parameters: del parameters['auxiliaryLoss']
|
251 |
+
if not useDSL:
|
252 |
+
for k in {"structurePenalty", "pseudoCounts", "aic"}:
|
253 |
+
del parameters[k]
|
254 |
+
else: del parameters["useDSL"]
|
255 |
+
|
256 |
+
# Uses `parameters` to construct the checkpoint path
|
257 |
+
def checkpointPath(iteration, extra=""):
|
258 |
+
parameters["iterations"] = iteration
|
259 |
+
kvs = [
|
260 |
+
"{}={}".format(
|
261 |
+
ECResult.abbreviate(k),
|
262 |
+
parameters[k]) for k in sorted(
|
263 |
+
parameters.keys())]
|
264 |
+
return "{}_{}{}.pickle".format(outputPrefix, "_".join(kvs), extra)
|
265 |
+
|
266 |
+
if message:
|
267 |
+
message = " (" + message + ")"
|
268 |
+
eprint("Running EC%s on %s @ %s with %d CPUs and parameters:" %
|
269 |
+
(message, os.uname()[1], datetime.datetime.now(), CPUs))
|
270 |
+
for k, v in parameters.items():
|
271 |
+
eprint("\t", k, " = ", v)
|
272 |
+
eprint("\t", "evaluationTimeout", " = ", evaluationTimeout)
|
273 |
+
eprint("\t", "cuda", " = ", cuda)
|
274 |
+
eprint()
|
275 |
+
|
276 |
+
if addFullTaskMetrics:
|
277 |
+
assert resume is not None, "--addFullTaskMetrics requires --resume"
|
278 |
+
|
279 |
+
def reportMemory():
|
280 |
+
eprint(f"Currently using this much memory: {getThisMemoryUsage()}")
|
281 |
+
|
282 |
+
# Restore checkpoint
|
283 |
+
if resume is not None:
|
284 |
+
try:
|
285 |
+
resume = int(resume)
|
286 |
+
path = checkpointPath(resume)
|
287 |
+
except ValueError:
|
288 |
+
path = resume
|
289 |
+
with open(path, "rb") as handle:
|
290 |
+
result = dill.load(handle)
|
291 |
+
resume = len(result.grammars) - 1
|
292 |
+
eprint("Loaded checkpoint from", path)
|
293 |
+
grammar = result.grammars[-1] if result.grammars else grammar
|
294 |
+
else: # Start from scratch
|
295 |
+
#for graphing of testing tasks
|
296 |
+
numTestingTasks = len(testingTasks) if len(testingTasks) != 0 else None
|
297 |
+
|
298 |
+
result = ECResult(parameters=parameters,
|
299 |
+
grammars=[grammar],
|
300 |
+
taskSolutions={
|
301 |
+
t: Frontier([],
|
302 |
+
task=t) for t in tasks},
|
303 |
+
recognitionModel=None, numTestingTasks=numTestingTasks,
|
304 |
+
allFrontiers={
|
305 |
+
t: Frontier([],
|
306 |
+
task=t) for t in tasks})
|
307 |
+
|
308 |
+
|
309 |
+
# Set up the task batcher.
|
310 |
+
if taskReranker == 'default':
|
311 |
+
taskBatcher = DefaultTaskBatcher()
|
312 |
+
elif taskReranker == 'random':
|
313 |
+
taskBatcher = RandomTaskBatcher()
|
314 |
+
elif taskReranker == 'randomShuffle':
|
315 |
+
taskBatcher = RandomShuffleTaskBatcher(seed)
|
316 |
+
elif taskReranker == 'unsolved':
|
317 |
+
taskBatcher = UnsolvedTaskBatcher()
|
318 |
+
elif taskReranker == 'unsolvedEntropy':
|
319 |
+
taskBatcher = UnsolvedEntropyTaskBatcher()
|
320 |
+
elif taskReranker == 'unsolvedRandomEntropy':
|
321 |
+
taskBatcher = UnsolvedRandomEntropyTaskBatcher()
|
322 |
+
elif taskReranker == 'randomkNN':
|
323 |
+
taskBatcher = RandomkNNTaskBatcher()
|
324 |
+
elif taskReranker == 'randomLowEntropykNN':
|
325 |
+
taskBatcher = RandomLowEntropykNNTaskBatcher()
|
326 |
+
else:
|
327 |
+
eprint("Invalid task reranker: " + taskReranker + ", aborting.")
|
328 |
+
assert False
|
329 |
+
|
330 |
+
# Check if we are just updating the full task metrics
|
331 |
+
if addFullTaskMetrics:
|
332 |
+
if testingTimeout is not None and testingTimeout > enumerationTimeout:
|
333 |
+
enumerationTimeout = testingTimeout
|
334 |
+
if result.recognitionModel is not None:
|
335 |
+
_enumerator = lambda *args, **kw: result.recognitionModel.enumerateFrontiers(*args, **kw)
|
336 |
+
else: _enumerator = lambda *args, **kw: multicoreEnumeration(result.grammars[-1], *args, **kw)
|
337 |
+
enumerator = lambda *args, **kw: _enumerator(*args,
|
338 |
+
maximumFrontier=maximumFrontier,
|
339 |
+
CPUs=CPUs, evaluationTimeout=evaluationTimeout,
|
340 |
+
solver=solver,
|
341 |
+
**kw)
|
342 |
+
trainFrontiers, _, trainingTimes = enumerator(tasks, enumerationTimeout=enumerationTimeout)
|
343 |
+
testFrontiers, _, testingTimes = enumerator(testingTasks, enumerationTimeout=testingTimeout, testing=True)
|
344 |
+
|
345 |
+
recognizer = result.recognitionModel
|
346 |
+
updateTaskSummaryMetrics(result.recognitionTaskMetrics, trainingTimes, 'recognitionBestTimes')
|
347 |
+
updateTaskSummaryMetrics(result.recognitionTaskMetrics, recognizer.taskGrammarLogProductions(tasks), 'taskLogProductions')
|
348 |
+
updateTaskSummaryMetrics(result.recognitionTaskMetrics, recognizer.taskGrammarEntropies(tasks), 'taskGrammarEntropies')
|
349 |
+
updateTaskSummaryMetrics(result.recognitionTaskMetrics, result.recognitionModel.taskAuxiliaryLossLayer(tasks), 'taskAuxiliaryLossLayer')
|
350 |
+
|
351 |
+
updateTaskSummaryMetrics(result.recognitionTaskMetrics, testingTimes, 'heldoutTestingTimes')
|
352 |
+
updateTaskSummaryMetrics(result.recognitionTaskMetrics, recognizer.taskGrammarLogProductions(testingTasks), 'heldoutTaskLogProductions')
|
353 |
+
updateTaskSummaryMetrics(result.recognitionTaskMetrics, recognizer.taskGrammarEntropies(testingTasks), 'heldoutTaskGrammarEntropies')
|
354 |
+
updateTaskSummaryMetrics(result.recognitionTaskMetrics, result.recognitionModel.taskAuxiliaryLossLayer(testingTasks), 'heldoutAuxiliaryLossLayer')
|
355 |
+
|
356 |
+
updateTaskSummaryMetrics(result.recognitionTaskMetrics, {f.task: f
|
357 |
+
for f in trainFrontiers + testFrontiers
|
358 |
+
if len(f) > 0},
|
359 |
+
'frontier')
|
360 |
+
SUFFIX = ".pickle"
|
361 |
+
assert path.endswith(SUFFIX)
|
362 |
+
path = path[:-len(SUFFIX)] + "_FTM=True" + SUFFIX
|
363 |
+
with open(path, "wb") as handle: dill.dump(result, handle)
|
364 |
+
if useRecognitionModel: ECResult.clearRecognitionModel(path)
|
365 |
+
|
366 |
+
sys.exit(0)
|
367 |
+
|
368 |
+
|
369 |
+
for j in range(resume or 0, iterations):
|
370 |
+
if storeTaskMetrics and rewriteTaskMetrics:
|
371 |
+
eprint("Resetting task metrics for next iteration.")
|
372 |
+
result.recognitionTaskMetrics = {}
|
373 |
+
|
374 |
+
reportMemory()
|
375 |
+
|
376 |
+
# Evaluate on held out tasks if we have them
|
377 |
+
if testingTimeout > 0 and ((j % testEvery == 0) or (j == iterations - 1)):
|
378 |
+
eprint("Evaluating on held out testing tasks for iteration: %d" % (j))
|
379 |
+
evaluateOnTestingTasks(result, testingTasks, grammar,
|
380 |
+
CPUs=CPUs, maximumFrontier=maximumFrontier,
|
381 |
+
solver=solver,
|
382 |
+
enumerationTimeout=testingTimeout, evaluationTimeout=evaluationTimeout)
|
383 |
+
# If we have to also enumerate Helmholtz frontiers,
|
384 |
+
# do this extra sneaky in the background
|
385 |
+
if useRecognitionModel and biasOptimal and helmholtzRatio > 0 and \
|
386 |
+
all( str(p) != "REAL" for p in grammar.primitives ): # real numbers don't support this
|
387 |
+
# the DSL is fixed, so the dreams are also fixed. don't recompute them.
|
388 |
+
if useDSL or 'helmholtzFrontiers' not in locals():
|
389 |
+
helmholtzFrontiers = backgroundHelmholtzEnumeration(tasks, grammar, enumerationTimeout,
|
390 |
+
evaluationTimeout=evaluationTimeout,
|
391 |
+
special=featureExtractor.special)
|
392 |
+
else:
|
393 |
+
print("Reusing dreams from previous iteration.")
|
394 |
+
else:
|
395 |
+
helmholtzFrontiers = lambda: []
|
396 |
+
|
397 |
+
reportMemory()
|
398 |
+
|
399 |
+
# Get waking task batch.
|
400 |
+
wakingTaskBatch = taskBatcher.getTaskBatch(result, tasks, taskBatchSize, j)
|
401 |
+
eprint("Using a waking task batch of size: " + str(len(wakingTaskBatch)))
|
402 |
+
|
403 |
+
# WAKING UP
|
404 |
+
if useDSL:
|
405 |
+
wake_generative = custom_wake_generative if custom_wake_generative is not None else default_wake_generative
|
406 |
+
topDownFrontiers, times = wake_generative(grammar, wakingTaskBatch,
|
407 |
+
solver=solver,
|
408 |
+
maximumFrontier=maximumFrontier,
|
409 |
+
enumerationTimeout=enumerationTimeout,
|
410 |
+
CPUs=CPUs,
|
411 |
+
evaluationTimeout=evaluationTimeout)
|
412 |
+
result.trainSearchTime = {t: tm for t, tm in times.items() if tm is not None}
|
413 |
+
else:
|
414 |
+
eprint("Skipping top-down enumeration because we are not using the generative model")
|
415 |
+
topDownFrontiers, times = [], {t: None for t in wakingTaskBatch }
|
416 |
+
|
417 |
+
tasksHitTopDown = {f.task for f in topDownFrontiers if not f.empty}
|
418 |
+
result.hitsAtEachWake.append(len(tasksHitTopDown))
|
419 |
+
|
420 |
+
reportMemory()
|
421 |
+
|
422 |
+
# Combine topDownFrontiers from this task batch with all frontiers.
|
423 |
+
for f in topDownFrontiers:
|
424 |
+
if f.task not in result.allFrontiers: continue # backward compatibility with old checkpoints
|
425 |
+
result.allFrontiers[f.task] = result.allFrontiers[f.task].combine(f).topK(maximumFrontier)
|
426 |
+
|
427 |
+
eprint("Frontiers discovered top down: " + str(len(tasksHitTopDown)))
|
428 |
+
eprint("Total frontiers: " + str(len([f for f in result.allFrontiers.values() if not f.empty])))
|
429 |
+
|
430 |
+
# Train + use recognition model
|
431 |
+
if useRecognitionModel:
|
432 |
+
# Should we initialize the weights to be what they were before?
|
433 |
+
previousRecognitionModel = None
|
434 |
+
if reuseRecognition and result.recognitionModel is not None:
|
435 |
+
previousRecognitionModel = result.recognitionModel
|
436 |
+
|
437 |
+
thisRatio = helmholtzRatio
|
438 |
+
#if j == 0 and not biasOptimal: thisRatio = 0
|
439 |
+
if all( f.empty for f in result.allFrontiers.values() ): thisRatio = 1.
|
440 |
+
|
441 |
+
tasksHitBottomUp = \
|
442 |
+
sleep_recognition(result, grammar, wakingTaskBatch, tasks, testingTasks, result.allFrontiers.values(),
|
443 |
+
ensembleSize=ensembleSize, featureExtractor=featureExtractor, mask=mask,
|
444 |
+
activation=activation, contextual=contextual, biasOptimal=biasOptimal,
|
445 |
+
previousRecognitionModel=previousRecognitionModel, matrixRank=matrixRank,
|
446 |
+
timeout=recognitionTimeout, evaluationTimeout=evaluationTimeout,
|
447 |
+
enumerationTimeout=enumerationTimeout,
|
448 |
+
helmholtzRatio=thisRatio, helmholtzFrontiers=helmholtzFrontiers(),
|
449 |
+
auxiliaryLoss=auxiliaryLoss, cuda=cuda, CPUs=CPUs, solver=solver,
|
450 |
+
recognitionSteps=recognitionSteps, maximumFrontier=maximumFrontier)
|
451 |
+
|
452 |
+
showHitMatrix(tasksHitTopDown, tasksHitBottomUp, wakingTaskBatch)
|
453 |
+
|
454 |
+
# Record the new topK solutions
|
455 |
+
result.taskSolutions = {f.task: f.topK(topK)
|
456 |
+
for f in result.allFrontiers.values()}
|
457 |
+
for f in result.allFrontiers.values(): result.recordFrontier(f)
|
458 |
+
result.learningCurve += [
|
459 |
+
sum(f is not None and not f.empty for f in result.taskSolutions.values())]
|
460 |
+
updateTaskSummaryMetrics(result.recognitionTaskMetrics, {f.task: f
|
461 |
+
for f in result.allFrontiers.values()
|
462 |
+
if len(f) > 0},
|
463 |
+
'frontier')
|
464 |
+
|
465 |
+
# Sleep-G
|
466 |
+
if useDSL and not(noConsolidation):
|
467 |
+
eprint(f"Currently using this much memory: {getThisMemoryUsage()}")
|
468 |
+
grammar = consolidate(result, grammar, topK=topK, pseudoCounts=pseudoCounts, arity=arity, aic=aic,
|
469 |
+
structurePenalty=structurePenalty, compressor=compressor, CPUs=CPUs,
|
470 |
+
iteration=j)
|
471 |
+
eprint(f"Currently using this much memory: {getThisMemoryUsage()}")
|
472 |
+
else:
|
473 |
+
eprint("Skipping consolidation.")
|
474 |
+
result.grammars.append(grammar)
|
475 |
+
|
476 |
+
if outputPrefix is not None:
|
477 |
+
path = checkpointPath(j + 1)
|
478 |
+
with open(path, "wb") as handle:
|
479 |
+
try:
|
480 |
+
dill.dump(result, handle)
|
481 |
+
except TypeError as e:
|
482 |
+
eprint(result)
|
483 |
+
assert(False)
|
484 |
+
eprint("Exported checkpoint to", path)
|
485 |
+
if useRecognitionModel:
|
486 |
+
ECResult.clearRecognitionModel(path)
|
487 |
+
|
488 |
+
graphPrimitives(result, "%s_primitives_%d_"%(outputPrefix,j))
|
489 |
+
|
490 |
+
|
491 |
+
yield result
|
492 |
+
|
493 |
+
|
494 |
+
def showHitMatrix(top, bottom, tasks):
|
495 |
+
tasks = set(tasks)
|
496 |
+
|
497 |
+
total = bottom | top
|
498 |
+
eprint(len(total), "/", len(tasks), "total hit tasks")
|
499 |
+
bottomMiss = tasks - bottom
|
500 |
+
topMiss = tasks - top
|
501 |
+
|
502 |
+
eprint("{: <13s}{: ^13s}{: ^13s}".format("", "bottom miss", "bottom hit"))
|
503 |
+
eprint("{: <13s}{: ^13d}{: ^13d}".format("top miss",
|
504 |
+
len(bottomMiss & topMiss),
|
505 |
+
len(bottom & topMiss)))
|
506 |
+
eprint("{: <13s}{: ^13d}{: ^13d}".format("top hit",
|
507 |
+
len(top & bottomMiss),
|
508 |
+
len(top & bottom)))
|
509 |
+
|
510 |
+
def evaluateOnTestingTasks(result, testingTasks, grammar, _=None,
|
511 |
+
CPUs=None, solver=None, maximumFrontier=None, enumerationTimeout=None, evaluationTimeout=None):
|
512 |
+
if result.recognitionModel is not None:
|
513 |
+
recognizer = result.recognitionModel
|
514 |
+
testingFrontiers, times = \
|
515 |
+
recognizer.enumerateFrontiers(testingTasks,
|
516 |
+
CPUs=CPUs,
|
517 |
+
solver=solver,
|
518 |
+
maximumFrontier=maximumFrontier,
|
519 |
+
enumerationTimeout=enumerationTimeout,
|
520 |
+
evaluationTimeout=evaluationTimeout,
|
521 |
+
testing=True)
|
522 |
+
updateTaskSummaryMetrics(result.recognitionTaskMetrics, recognizer.taskGrammarLogProductions(testingTasks), 'heldoutTaskLogProductions')
|
523 |
+
updateTaskSummaryMetrics(result.recognitionTaskMetrics, recognizer.taskGrammarEntropies(testingTasks), 'heldoutTaskGrammarEntropies')
|
524 |
+
updateTaskSummaryMetrics(result.recognitionTaskMetrics, recognizer.taskGrammarEntropies(testingTasks), 'heldoutTaskGrammarEntropies')
|
525 |
+
else:
|
526 |
+
testingFrontiers, times = multicoreEnumeration(grammar, testingTasks,
|
527 |
+
solver=solver,
|
528 |
+
maximumFrontier=maximumFrontier,
|
529 |
+
enumerationTimeout=enumerationTimeout,
|
530 |
+
CPUs=CPUs,
|
531 |
+
evaluationTimeout=evaluationTimeout,
|
532 |
+
testing=True)
|
533 |
+
updateTaskSummaryMetrics(result.recognitionTaskMetrics, times, 'heldoutTestingTimes')
|
534 |
+
updateTaskSummaryMetrics(result.recognitionTaskMetrics,
|
535 |
+
{f.task: f for f in testingFrontiers if len(f) > 0 },
|
536 |
+
'frontier')
|
537 |
+
for f in testingFrontiers: result.recordFrontier(f)
|
538 |
+
result.testSearchTime = {t: tm for t, tm in times.items() if tm is not None}
|
539 |
+
times = [t for t in times.values() if t is not None ]
|
540 |
+
eprint("\n".join(f.summarize() for f in testingFrontiers))
|
541 |
+
summaryStatistics("Testing tasks", times)
|
542 |
+
eprint("Hits %d/%d testing tasks" % (len(times), len(testingTasks)))
|
543 |
+
result.testingSearchTime.append(times)
|
544 |
+
|
545 |
+
|
546 |
+
def default_wake_generative(grammar, tasks,
|
547 |
+
maximumFrontier=None,
|
548 |
+
enumerationTimeout=None,
|
549 |
+
CPUs=None,
|
550 |
+
solver=None,
|
551 |
+
evaluationTimeout=None):
|
552 |
+
topDownFrontiers, times = multicoreEnumeration(grammar, tasks,
|
553 |
+
maximumFrontier=maximumFrontier,
|
554 |
+
enumerationTimeout=enumerationTimeout,
|
555 |
+
CPUs=CPUs,
|
556 |
+
solver=solver,
|
557 |
+
evaluationTimeout=evaluationTimeout)
|
558 |
+
eprint("Generative model enumeration results:")
|
559 |
+
eprint(Frontier.describe(topDownFrontiers))
|
560 |
+
summaryStatistics("Generative model", [t for t in times.values() if t is not None])
|
561 |
+
return topDownFrontiers, times
|
562 |
+
|
563 |
+
def sleep_recognition(result, grammar, taskBatch, tasks, testingTasks, allFrontiers, _=None,
|
564 |
+
ensembleSize=1, featureExtractor=None, matrixRank=None, mask=False,
|
565 |
+
activation=None, contextual=True, biasOptimal=True,
|
566 |
+
previousRecognitionModel=None, recognitionSteps=None,
|
567 |
+
timeout=None, enumerationTimeout=None, evaluationTimeout=None,
|
568 |
+
helmholtzRatio=None, helmholtzFrontiers=None, maximumFrontier=None,
|
569 |
+
auxiliaryLoss=None, cuda=None, CPUs=None, solver=None):
|
570 |
+
eprint("Using an ensemble size of %d. Note that we will only store and test on the best recognition model." % ensembleSize)
|
571 |
+
|
572 |
+
featureExtractorObjects = [featureExtractor(tasks, testingTasks=testingTasks, cuda=cuda) for i in range(ensembleSize)]
|
573 |
+
recognizers = [RecognitionModel(featureExtractorObjects[i],
|
574 |
+
grammar,
|
575 |
+
mask=mask,
|
576 |
+
rank=matrixRank,
|
577 |
+
activation=activation,
|
578 |
+
cuda=cuda,
|
579 |
+
contextual=contextual,
|
580 |
+
previousRecognitionModel=previousRecognitionModel,
|
581 |
+
id=i) for i in range(ensembleSize)]
|
582 |
+
eprint(f"Currently using this much memory: {getThisMemoryUsage()}")
|
583 |
+
trainedRecognizers = parallelMap(min(CPUs,len(recognizers)),
|
584 |
+
lambda recognizer: recognizer.train(allFrontiers,
|
585 |
+
biasOptimal=biasOptimal,
|
586 |
+
helmholtzFrontiers=helmholtzFrontiers,
|
587 |
+
CPUs=CPUs,
|
588 |
+
evaluationTimeout=evaluationTimeout,
|
589 |
+
timeout=timeout,
|
590 |
+
steps=recognitionSteps,
|
591 |
+
helmholtzRatio=helmholtzRatio,
|
592 |
+
auxLoss=auxiliaryLoss,
|
593 |
+
vectorized=True),
|
594 |
+
recognizers,
|
595 |
+
seedRandom=True)
|
596 |
+
eprint(f"Currently using this much memory: {getThisMemoryUsage()}")
|
597 |
+
# Enumerate frontiers for each of the recognizers.
|
598 |
+
eprint("Trained an ensemble of %d recognition models, now enumerating." % len(trainedRecognizers))
|
599 |
+
ensembleFrontiers, ensembleTimes, ensembleRecognitionTimes = [], [], []
|
600 |
+
mostTasks = 0
|
601 |
+
bestRecognizer = None
|
602 |
+
totalTasksHitBottomUp = set()
|
603 |
+
for recIndex, recognizer in enumerate(trainedRecognizers):
|
604 |
+
eprint("Enumerating from recognizer %d of %d" % (recIndex, len(trainedRecognizers)))
|
605 |
+
bottomupFrontiers, allRecognitionTimes = \
|
606 |
+
recognizer.enumerateFrontiers(taskBatch,
|
607 |
+
CPUs=CPUs,
|
608 |
+
maximumFrontier=maximumFrontier,
|
609 |
+
enumerationTimeout=enumerationTimeout,
|
610 |
+
evaluationTimeout=evaluationTimeout,
|
611 |
+
solver=solver)
|
612 |
+
ensembleFrontiers.append(bottomupFrontiers)
|
613 |
+
ensembleTimes.append([t for t in allRecognitionTimes.values() if t is not None])
|
614 |
+
ensembleRecognitionTimes.append(allRecognitionTimes)
|
615 |
+
|
616 |
+
recognizerTasksHitBottomUp = {f.task for f in bottomupFrontiers if not f.empty}
|
617 |
+
totalTasksHitBottomUp.update(recognizerTasksHitBottomUp)
|
618 |
+
eprint("Recognizer %d solved %d/%d tasks; total tasks solved is now %d." % (recIndex, len(recognizerTasksHitBottomUp), len(tasks), len(totalTasksHitBottomUp)))
|
619 |
+
if len(recognizerTasksHitBottomUp) >= mostTasks:
|
620 |
+
# TODO (cathywong): could consider keeping the one that put the highest likelihood on the solved tasks.
|
621 |
+
bestRecognizer = recIndex
|
622 |
+
|
623 |
+
# Store the recognizer that discovers the most frontiers in the result.
|
624 |
+
eprint("Best recognizer: %d." % bestRecognizer)
|
625 |
+
result.recognitionModel = trainedRecognizers[bestRecognizer]
|
626 |
+
result.trainSearchTime = {tk: tm for tk, tm in ensembleRecognitionTimes[bestRecognizer].items()
|
627 |
+
if tm is not None}
|
628 |
+
updateTaskSummaryMetrics(result.recognitionTaskMetrics, ensembleRecognitionTimes[bestRecognizer], 'recognitionBestTimes')
|
629 |
+
updateTaskSummaryMetrics(result.recognitionTaskMetrics, result.recognitionModel.taskHiddenStates(tasks), 'hiddenState')
|
630 |
+
updateTaskSummaryMetrics(result.recognitionTaskMetrics, result.recognitionModel.taskGrammarLogProductions(tasks), 'taskLogProductions')
|
631 |
+
updateTaskSummaryMetrics(result.recognitionTaskMetrics, result.recognitionModel.taskGrammarEntropies(tasks), 'taskGrammarEntropies')
|
632 |
+
if contextual:
|
633 |
+
updateTaskSummaryMetrics(result.recognitionTaskMetrics,
|
634 |
+
result.recognitionModel.taskGrammarStartProductions(tasks),
|
635 |
+
'startProductions')
|
636 |
+
|
637 |
+
result.hitsAtEachWake.append(len(totalTasksHitBottomUp))
|
638 |
+
eprint(f"Currently using this much memory: {getThisMemoryUsage()}")
|
639 |
+
|
640 |
+
""" Rescore and combine the frontiers across the ensemble of recognition models."""
|
641 |
+
eprint("Recognition model enumeration results for the best recognizer.")
|
642 |
+
eprint(Frontier.describe(ensembleFrontiers[bestRecognizer]))
|
643 |
+
summaryStatistics("Recognition model", ensembleTimes[bestRecognizer])
|
644 |
+
|
645 |
+
eprint("Cumulative results for the full ensemble of %d recognizers: " % len(trainedRecognizers))
|
646 |
+
# Rescore all of the ensemble frontiers according to the generative model
|
647 |
+
# and then combine w/ original frontiers
|
648 |
+
for bottomupFrontiers in ensembleFrontiers:
|
649 |
+
for b in bottomupFrontiers:
|
650 |
+
if b.task not in result.allFrontiers: continue # backwards compatibility with old checkpoints
|
651 |
+
result.allFrontiers[b.task] = result.allFrontiers[b.task].\
|
652 |
+
combine(grammar.rescoreFrontier(b)).\
|
653 |
+
topK(maximumFrontier)
|
654 |
+
|
655 |
+
eprint("Frontiers discovered bottom up: " + str(len(totalTasksHitBottomUp)))
|
656 |
+
eprint("Total frontiers: " + str(len([f for f in result.allFrontiers.values() if not f.empty])))
|
657 |
+
|
658 |
+
result.searchTimes.append(ensembleTimes[bestRecognizer])
|
659 |
+
if len(ensembleTimes[bestRecognizer]) > 0:
|
660 |
+
eprint("Average search time: ", int(mean(ensembleTimes[bestRecognizer]) + 0.5),
|
661 |
+
"sec.\tmedian:", int(median(ensembleTimes[bestRecognizer]) + 0.5),
|
662 |
+
"\tmax:", int(max(ensembleTimes[bestRecognizer]) + 0.5),
|
663 |
+
"\tstandard deviation", int(standardDeviation(ensembleTimes[bestRecognizer]) + 0.5))
|
664 |
+
return totalTasksHitBottomUp
|
665 |
+
|
666 |
+
def consolidate(result, grammar, _=None, topK=None, arity=None, pseudoCounts=None, aic=None,
|
667 |
+
structurePenalty=None, compressor=None, CPUs=None, iteration=None):
|
668 |
+
eprint("Showing the top 5 programs in each frontier being sent to the compressor:")
|
669 |
+
for f in result.allFrontiers.values():
|
670 |
+
if f.empty:
|
671 |
+
continue
|
672 |
+
eprint(f.task)
|
673 |
+
for e in f.normalize().topK(5):
|
674 |
+
eprint("%.02f\t%s" % (e.logPosterior, e.program))
|
675 |
+
eprint()
|
676 |
+
|
677 |
+
# First check if we have supervision at the program level for any task that was not solved
|
678 |
+
needToSupervise = {f.task for f in result.allFrontiers.values()
|
679 |
+
if f.task.supervision is not None and f.empty}
|
680 |
+
compressionFrontiers = [f.replaceWithSupervised(grammar) if f.task in needToSupervise else f
|
681 |
+
for f in result.allFrontiers.values() ]
|
682 |
+
|
683 |
+
if len([f for f in compressionFrontiers if not f.empty]) == 0:
|
684 |
+
eprint("No compression frontiers; not inducing a grammar this iteration.")
|
685 |
+
else:
|
686 |
+
grammar, compressionFrontiers = induceGrammar(grammar, compressionFrontiers,
|
687 |
+
topK=topK,
|
688 |
+
pseudoCounts=pseudoCounts, a=arity,
|
689 |
+
aic=aic, structurePenalty=structurePenalty,
|
690 |
+
topk_use_only_likelihood=False,
|
691 |
+
backend=compressor, CPUs=CPUs, iteration=iteration)
|
692 |
+
# Store compression frontiers in the result.
|
693 |
+
for c in compressionFrontiers:
|
694 |
+
result.allFrontiers[c.task] = c.topK(0) if c in needToSupervise else c
|
695 |
+
|
696 |
+
|
697 |
+
result.grammars.append(grammar)
|
698 |
+
eprint("Grammar after iteration %d:" % (iteration + 1))
|
699 |
+
eprint(grammar)
|
700 |
+
|
701 |
+
return grammar
|
702 |
+
|
703 |
+
|
704 |
+
|
705 |
+
def commandlineArguments(_=None,
|
706 |
+
iterations=None,
|
707 |
+
enumerationTimeout=None,
|
708 |
+
testEvery=1,
|
709 |
+
topK=1,
|
710 |
+
reuseRecognition=False,
|
711 |
+
CPUs=1,
|
712 |
+
solver='ocaml',
|
713 |
+
compressor="ocaml",
|
714 |
+
useRecognitionModel=True,
|
715 |
+
recognitionTimeout=None,
|
716 |
+
activation='relu',
|
717 |
+
helmholtzRatio=1.,
|
718 |
+
featureExtractor=None,
|
719 |
+
cuda=None,
|
720 |
+
maximumFrontier=None,
|
721 |
+
pseudoCounts=1.0, aic=1.0,
|
722 |
+
structurePenalty=0.001, a=0,
|
723 |
+
taskBatchSize=None, taskReranker="default",
|
724 |
+
extras=None,
|
725 |
+
storeTaskMetrics=False,
|
726 |
+
rewriteTaskMetrics=True):
|
727 |
+
if cuda is None:
|
728 |
+
cuda = torch.cuda.is_available()
|
729 |
+
print("CUDA is available?:", torch.cuda.is_available())
|
730 |
+
print("using cuda?:", cuda)
|
731 |
+
import argparse
|
732 |
+
parser = argparse.ArgumentParser(description="")
|
733 |
+
parser.add_argument("--resume",
|
734 |
+
help="Resumes EC algorithm from checkpoint. You can either pass in the path of a checkpoint, or you can pass in the iteration to resume from, in which case it will try to figure out the path.",
|
735 |
+
default=None,
|
736 |
+
type=str)
|
737 |
+
parser.add_argument("-i", "--iterations",
|
738 |
+
help="default: %d" % iterations,
|
739 |
+
default=iterations,
|
740 |
+
type=int)
|
741 |
+
parser.add_argument("-t", "--enumerationTimeout",
|
742 |
+
default=enumerationTimeout,
|
743 |
+
help="In seconds. default: %s" % enumerationTimeout,
|
744 |
+
type=int)
|
745 |
+
parser.add_argument("-R", "--recognitionTimeout",
|
746 |
+
default=recognitionTimeout,
|
747 |
+
help="In seconds. Amount of time to train the recognition model on each iteration. Defaults to enumeration timeout.",
|
748 |
+
type=int)
|
749 |
+
parser.add_argument("-RS", "--recognitionSteps",
|
750 |
+
default=None,
|
751 |
+
help="Number of gradient steps to train the recognition model. Can be specified instead of train time.",
|
752 |
+
type=int)
|
753 |
+
parser.add_argument(
|
754 |
+
"-k",
|
755 |
+
"--topK",
|
756 |
+
default=topK,
|
757 |
+
help="When training generative and discriminative models, we train them to fit the top K programs. Ideally we would train them to fit the entire frontier, but this is often intractable. default: %d" %
|
758 |
+
topK,
|
759 |
+
type=int)
|
760 |
+
parser.add_argument("-p", "--pseudoCounts",
|
761 |
+
default=pseudoCounts,
|
762 |
+
help="default: %f" % pseudoCounts,
|
763 |
+
type=float)
|
764 |
+
parser.add_argument("-b", "--aic",
|
765 |
+
default=aic,
|
766 |
+
help="default: %f" % aic,
|
767 |
+
type=float)
|
768 |
+
parser.add_argument("-l", "--structurePenalty",
|
769 |
+
default=structurePenalty,
|
770 |
+
help="default: %f" % structurePenalty,
|
771 |
+
type=float)
|
772 |
+
parser.add_argument("-a", "--arity",
|
773 |
+
default=a,
|
774 |
+
help="default: %d" % a,
|
775 |
+
type=int)
|
776 |
+
parser.add_argument("-c", "--CPUs",
|
777 |
+
default=CPUs,
|
778 |
+
help="default: %d" % CPUs,
|
779 |
+
type=int)
|
780 |
+
parser.add_argument("--no-cuda",
|
781 |
+
action="store_false",
|
782 |
+
dest="cuda",
|
783 |
+
help="""cuda will be used if available (which it %s),
|
784 |
+
unless this is set""" % ("IS" if cuda else "ISN'T"))
|
785 |
+
parser.add_argument("-m", "--maximumFrontier",
|
786 |
+
help="""Even though we enumerate --frontierSize
|
787 |
+
programs, we might want to only keep around the very
|
788 |
+
best for performance reasons. This is a cut off on the
|
789 |
+
maximum size of the frontier that is kept around.
|
790 |
+
Default: %s""" % maximumFrontier,
|
791 |
+
type=int)
|
792 |
+
parser.add_argument("--reuseRecognition",
|
793 |
+
help="""Should we initialize recognition model weights to be what they were at the previous DreamCoder iteration? Default: %s""" % reuseRecognition,
|
794 |
+
default=reuseRecognition,
|
795 |
+
action="store_true")
|
796 |
+
parser.add_argument("--recognition",
|
797 |
+
dest="useRecognitionModel",
|
798 |
+
action="store_true",
|
799 |
+
help="""Enable bottom-up neural recognition model.
|
800 |
+
Default: %s""" % useRecognitionModel)
|
801 |
+
parser.add_argument("--ensembleSize",
|
802 |
+
dest="ensembleSize",
|
803 |
+
default=1,
|
804 |
+
help="Number of recognition models to train and enumerate from at each iteration.",
|
805 |
+
type=int)
|
806 |
+
parser.add_argument("-g", "--no-recognition",
|
807 |
+
dest="useRecognitionModel",
|
808 |
+
action="store_false",
|
809 |
+
help="""Disable bottom-up neural recognition model.
|
810 |
+
Default: %s""" % (not useRecognitionModel))
|
811 |
+
parser.add_argument("-d", "--no-dsl",
|
812 |
+
dest="useDSL",
|
813 |
+
action="store_false",
|
814 |
+
help="""Disable DSL enumeration and updating.""")
|
815 |
+
parser.add_argument("--no-consolidation",
|
816 |
+
dest="noConsolidation",
|
817 |
+
action="store_true",
|
818 |
+
help="""Disable DSL updating.""")
|
819 |
+
parser.add_argument(
|
820 |
+
"--testingTimeout",
|
821 |
+
type=int,
|
822 |
+
dest="testingTimeout",
|
823 |
+
default=0,
|
824 |
+
help="Number of seconds we should spend evaluating on each held out testing task.")
|
825 |
+
parser.add_argument(
|
826 |
+
"--testEvery",
|
827 |
+
type=int,
|
828 |
+
dest="testEvery",
|
829 |
+
default=1,
|
830 |
+
help="Run heldout testing every X iterations."
|
831 |
+
)
|
832 |
+
parser.add_argument(
|
833 |
+
"--seed",
|
834 |
+
type=int,
|
835 |
+
default=0,
|
836 |
+
help="Random seed. Currently this only matters for random batching strategies.")
|
837 |
+
parser.add_argument(
|
838 |
+
"--activation",
|
839 |
+
choices=[
|
840 |
+
"relu",
|
841 |
+
"sigmoid",
|
842 |
+
"tanh"],
|
843 |
+
default=activation,
|
844 |
+
help="""Activation function for neural recognition model.
|
845 |
+
Default: %s""" %
|
846 |
+
activation)
|
847 |
+
parser.add_argument(
|
848 |
+
"--solver",
|
849 |
+
choices=[
|
850 |
+
"ocaml",
|
851 |
+
"pypy",
|
852 |
+
"python"],
|
853 |
+
default=solver,
|
854 |
+
help="""Solver for enumeration.
|
855 |
+
Default: %s""" %
|
856 |
+
solver)
|
857 |
+
parser.add_argument(
|
858 |
+
"-r",
|
859 |
+
"--Helmholtz",
|
860 |
+
dest="helmholtzRatio",
|
861 |
+
help="""When training recognition models, what fraction of the training data should be samples from the generative model? Default %f""" %
|
862 |
+
helmholtzRatio,
|
863 |
+
default=helmholtzRatio,
|
864 |
+
type=float)
|
865 |
+
parser.add_argument(
|
866 |
+
"--compressor",
|
867 |
+
default=compressor,
|
868 |
+
choices=["pypy","rust","vs","pypy_vs","ocaml","memorize"])
|
869 |
+
parser.add_argument(
|
870 |
+
"--matrixRank",
|
871 |
+
help="Maximum rank of bigram transition matrix for contextual recognition model. Defaults to full rank.",
|
872 |
+
default=None,
|
873 |
+
type=int)
|
874 |
+
parser.add_argument(
|
875 |
+
"--mask",
|
876 |
+
help="Unconditional bigram masking",
|
877 |
+
default=False, action="store_true")
|
878 |
+
parser.add_argument("--biasOptimal",
|
879 |
+
help="Enumerate dreams rather than sample them & bias-optimal recognition objective",
|
880 |
+
default=False, action="store_true")
|
881 |
+
parser.add_argument("--contextual",
|
882 |
+
help="bigram recognition model (default is unigram model)",
|
883 |
+
default=False, action="store_true")
|
884 |
+
parser.add_argument("--clear-recognition",
|
885 |
+
dest="clear-recognition",
|
886 |
+
help="Clears the recognition model from a checkpoint. Necessary for graphing results with recognition models, because pickle is kind of stupid sometimes.",
|
887 |
+
default=None,
|
888 |
+
type=str)
|
889 |
+
parser.add_argument("--primitive-graph",
|
890 |
+
dest="primitive-graph",
|
891 |
+
nargs='+',
|
892 |
+
help="Displays a dependency graph of the learned primitives",
|
893 |
+
default=None,
|
894 |
+
type=str)
|
895 |
+
parser.add_argument(
|
896 |
+
"--taskBatchSize",
|
897 |
+
dest="taskBatchSize",
|
898 |
+
help="Number of tasks to train on during wake. Defaults to all tasks if None.",
|
899 |
+
default=None,
|
900 |
+
type=int)
|
901 |
+
parser.add_argument(
|
902 |
+
"--taskReranker",
|
903 |
+
dest="taskReranker",
|
904 |
+
help="Reranking function used to order the tasks we train on during waking.",
|
905 |
+
choices=[
|
906 |
+
"default",
|
907 |
+
"random",
|
908 |
+
"randomShuffle",
|
909 |
+
"unsolved",
|
910 |
+
"unsolvedEntropy",
|
911 |
+
"unsolvedRandomEntropy",
|
912 |
+
"randomkNN",
|
913 |
+
"randomLowEntropykNN"],
|
914 |
+
default=taskReranker,
|
915 |
+
type=str)
|
916 |
+
parser.add_argument(
|
917 |
+
"--storeTaskMetrics",
|
918 |
+
dest="storeTaskMetrics",
|
919 |
+
default=True,
|
920 |
+
help="Whether to store task metrics directly in the ECResults.",
|
921 |
+
action="store_true"
|
922 |
+
)
|
923 |
+
parser.add_argument(
|
924 |
+
"--rewriteTaskMetrics",
|
925 |
+
dest="rewriteTaskMetrics",
|
926 |
+
help="Whether to rewrite a new task metrics dictionary at each iteration, rather than retaining the old.",
|
927 |
+
action="store_true"
|
928 |
+
)
|
929 |
+
parser.add_argument("--addTaskMetrics",
|
930 |
+
dest="addTaskMetrics",
|
931 |
+
help="Creates a checkpoint with task metrics and no recognition model for graphing.",
|
932 |
+
default=None,
|
933 |
+
nargs='+',
|
934 |
+
type=str)
|
935 |
+
parser.add_argument("--auxiliary",
|
936 |
+
action="store_true", default=False,
|
937 |
+
help="Add auxiliary classification loss to recognition network training",
|
938 |
+
dest="auxiliaryLoss")
|
939 |
+
parser.add_argument("--addFullTaskMetrics",
|
940 |
+
help="Only to be used in conjunction with --resume. Loads checkpoint, solves both testing and training tasks, stores frontiers, solve times, and task metrics, and then dies.",
|
941 |
+
default=False,
|
942 |
+
action="store_true")
|
943 |
+
parser.add_argument("--countParameters",
|
944 |
+
help="Load a checkpoint then report how many parameters are in the recognition model.",
|
945 |
+
default=None, type=str)
|
946 |
+
parser.set_defaults(useRecognitionModel=useRecognitionModel,
|
947 |
+
useDSL=True,
|
948 |
+
featureExtractor=featureExtractor,
|
949 |
+
maximumFrontier=maximumFrontier,
|
950 |
+
cuda=cuda)
|
951 |
+
if extras is not None:
|
952 |
+
extras(parser)
|
953 |
+
v = vars(parser.parse_args())
|
954 |
+
if v["clear-recognition"] is not None:
|
955 |
+
ECResult.clearRecognitionModel(v["clear-recognition"])
|
956 |
+
sys.exit(0)
|
957 |
+
else:
|
958 |
+
del v["clear-recognition"]
|
959 |
+
|
960 |
+
if v["primitive-graph"] is not None:
|
961 |
+
|
962 |
+
for n,pg in enumerate(v["primitive-graph"]):
|
963 |
+
with open(pg,'rb') as handle:
|
964 |
+
result = dill.load(handle)
|
965 |
+
graphPrimitives(result,f"figures/deepProgramLearning/{sys.argv[0]}{n}",view=True)
|
966 |
+
sys.exit(0)
|
967 |
+
else:
|
968 |
+
del v["primitive-graph"]
|
969 |
+
|
970 |
+
if v["addTaskMetrics"] is not None:
|
971 |
+
for path in v["addTaskMetrics"]:
|
972 |
+
with open(path,'rb') as handle:
|
973 |
+
result = dill.load(handle)
|
974 |
+
addTaskMetrics(result, path)
|
975 |
+
sys.exit(0)
|
976 |
+
else:
|
977 |
+
del v["addTaskMetrics"]
|
978 |
+
|
979 |
+
if v["useRecognitionModel"] and v["recognitionTimeout"] is None:
|
980 |
+
v["recognitionTimeout"] = v["enumerationTimeout"]
|
981 |
+
|
982 |
+
if v["countParameters"]:
|
983 |
+
with open(v["countParameters"], "rb") as handle:
|
984 |
+
result = dill.load(handle)
|
985 |
+
eprint("The recognition model has",
|
986 |
+
sum(p.numel() for p in result.recognitionModel.parameters() if p.requires_grad),
|
987 |
+
"trainable parameters and",
|
988 |
+
sum(p.numel() for p in result.recognitionModel.parameters() ),
|
989 |
+
"total parameters.\n",
|
990 |
+
"The feature extractor accounts for",
|
991 |
+
sum(p.numel() for p in result.recognitionModel.featureExtractor.parameters() ),
|
992 |
+
"of those parameters.\n",
|
993 |
+
"The grammar builder accounts for",
|
994 |
+
sum(p.numel() for p in result.recognitionModel.grammarBuilder.parameters() ),
|
995 |
+
"of those parameters.\n")
|
996 |
+
sys.exit(0)
|
997 |
+
del v["countParameters"]
|
998 |
+
|
999 |
+
|
1000 |
+
return v
|
1001 |
+
|
1002 |
+
def addTaskMetrics(result, path):
|
1003 |
+
"""Adds a task metrics to ECResults that were pickled without them."""
|
1004 |
+
with torch.no_grad(): return addTaskMetrics_(result, path)
|
1005 |
+
def addTaskMetrics_(result, path):
|
1006 |
+
SUFFIX = '.pickle'
|
1007 |
+
assert path.endswith(SUFFIX)
|
1008 |
+
|
1009 |
+
tasks = result.taskSolutions.keys()
|
1010 |
+
everyTask = set(tasks)
|
1011 |
+
for t in result.recognitionTaskMetrics:
|
1012 |
+
if isinstance(t, Task) and t not in everyTask: everyTask.add(t)
|
1013 |
+
|
1014 |
+
eprint(f"Found {len(tasks)} training tasks.")
|
1015 |
+
eprint(f"Scrounged up {len(everyTask) - len(tasks)} testing tasks.")
|
1016 |
+
if not hasattr(result, "recognitionTaskMetrics") or result.recognitionTaskMetrics is None:
|
1017 |
+
result.recognitionTaskMetrics = {}
|
1018 |
+
|
1019 |
+
# If task has images, store them.
|
1020 |
+
if hasattr(list(tasks)[0], 'getImage'):
|
1021 |
+
images = {t: t.getImage(pretty=True) for t in tasks}
|
1022 |
+
updateTaskSummaryMetrics(result.recognitionTaskMetrics, images, 'taskImages')
|
1023 |
+
|
1024 |
+
if hasattr(list(tasks)[0], 'highresolution'):
|
1025 |
+
images = {t: t.highresolution for t in tasks}
|
1026 |
+
updateTaskSummaryMetrics(result.recognitionTaskMetrics, images, 'taskImages')
|
1027 |
+
|
1028 |
+
updateTaskSummaryMetrics(result.recognitionTaskMetrics, result.recognitionModel.auxiliaryPrimitiveEmbeddings(), 'auxiliaryPrimitiveEmbeddings')
|
1029 |
+
updateTaskSummaryMetrics(result.recognitionTaskMetrics, result.recognitionModel.taskAuxiliaryLossLayer(tasks), 'taskAuxiliaryLossLayer')
|
1030 |
+
updateTaskSummaryMetrics(result.recognitionTaskMetrics, result.recognitionModel.taskAuxiliaryLossLayer(everyTask), 'every_auxiliaryLossLayer')
|
1031 |
+
|
1032 |
+
updateTaskSummaryMetrics(result.recognitionTaskMetrics, result.recognitionModel.taskGrammarFeatureLogProductions(tasks), 'grammarFeatureLogProductions')
|
1033 |
+
updateTaskSummaryMetrics(result.recognitionTaskMetrics, result.recognitionModel.taskGrammarFeatureLogProductions(everyTask), 'every_grammarFeatureLogProductions')
|
1034 |
+
updateTaskSummaryMetrics(result.recognitionTaskMetrics, result.recognitionModel.taskGrammarLogProductions(tasks), 'contextualLogProductions')
|
1035 |
+
updateTaskSummaryMetrics(result.recognitionTaskMetrics, result.recognitionModel.taskGrammarLogProductions(everyTask), 'every_contextualLogProductions')
|
1036 |
+
updateTaskSummaryMetrics(result.recognitionTaskMetrics, result.recognitionModel.taskHiddenStates(tasks), 'hiddenState')
|
1037 |
+
updateTaskSummaryMetrics(result.recognitionTaskMetrics, result.recognitionModel.taskHiddenStates(everyTask), 'every_hiddenState')
|
1038 |
+
g = result.grammars[-2] # the final entry in result.grammars is a grammar that we have not used yet
|
1039 |
+
updateTaskSummaryMetrics(result.recognitionTaskMetrics, {f.task: f.expectedProductionUses(g)
|
1040 |
+
for f in result.taskSolutions.values()
|
1041 |
+
if len(f) > 0},
|
1042 |
+
'expectedProductionUses')
|
1043 |
+
updateTaskSummaryMetrics(result.recognitionTaskMetrics, {f.task: f.expectedProductionUses(g)
|
1044 |
+
for t, metrics in result.recognitionTaskMetrics.items()
|
1045 |
+
if "frontier" in metrics
|
1046 |
+
for f in [metrics["frontier"]]
|
1047 |
+
if len(f) > 0},
|
1048 |
+
'every_expectedProductionUses')
|
1049 |
+
if False:
|
1050 |
+
eprint(f"About to do an expensive Monte Carlo simulation w/ {len(tasks)} tasks")
|
1051 |
+
updateTaskSummaryMetrics(result.recognitionTaskMetrics,
|
1052 |
+
{task: result.recognitionModel.grammarOfTask(task).untorch().expectedUsesMonteCarlo(task.request, debug=False)
|
1053 |
+
for task in tasks },
|
1054 |
+
'expectedProductionUsesMonteCarlo')
|
1055 |
+
try:
|
1056 |
+
updateTaskSummaryMetrics(result.recognitionTaskMetrics,
|
1057 |
+
result.recognitionModel.taskGrammarStartProductions(tasks),
|
1058 |
+
'startProductions')
|
1059 |
+
except: pass # can fail if we do not have a contextual model
|
1060 |
+
|
1061 |
+
#updateTaskSummaryMetrics(result.recognitionTaskMetrics, result.recognitionModel.taskGrammarLogProductions(tasks), 'task_no_parent_log_productions')
|
1062 |
+
#updateTaskSummaryMetrics(result.recognitionTaskMetrics, result.recognitionModel.taskGrammarEntropies(tasks), 'taskGrammarEntropies')
|
1063 |
+
|
1064 |
+
result.recognitionModel = None
|
1065 |
+
|
1066 |
+
clearedPath = path[:-len(SUFFIX)] + "_graph=True" + SUFFIX
|
1067 |
+
with open(clearedPath,'wb') as handle:
|
1068 |
+
result = dill.dump(result, handle)
|
1069 |
+
eprint(" [+] Cleared recognition model from:")
|
1070 |
+
eprint(" %s"%path)
|
1071 |
+
eprint(" and exported to:")
|
1072 |
+
eprint(" %s"%clearedPath)
|
1073 |
+
eprint(" Use this one for graphing.")
|
1074 |
+
|
dreamcoder/dreaming.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import subprocess
|
4 |
+
|
5 |
+
from dreamcoder.domains.arithmetic.arithmeticPrimitives import k1, k0, addition, subtraction, multiplication
|
6 |
+
from dreamcoder.frontier import Frontier, FrontierEntry
|
7 |
+
from dreamcoder.grammar import Grammar
|
8 |
+
from dreamcoder.program import Program
|
9 |
+
from dreamcoder.task import Task
|
10 |
+
from dreamcoder.type import arrow, tint
|
11 |
+
from dreamcoder.utilities import tuplify, timing, eprint, get_root_dir, mean
|
12 |
+
|
13 |
+
|
14 |
+
def helmholtzEnumeration(g, request, inputs, timeout, _=None,
|
15 |
+
special=None, evaluationTimeout=None):
|
16 |
+
"""Returns json (as text)"""
|
17 |
+
message = {"request": request.json(),
|
18 |
+
"timeout": timeout,
|
19 |
+
"DSL": g.json(),
|
20 |
+
"extras": inputs}
|
21 |
+
if evaluationTimeout: message["evaluationTimeout"] = evaluationTimeout
|
22 |
+
if special: message["special"] = special
|
23 |
+
message = json.dumps(message)
|
24 |
+
with open('/tmp/hm', 'w') as handle:
|
25 |
+
handle.write(message)
|
26 |
+
try:
|
27 |
+
binary = os.path.join(get_root_dir(), 'helmholtz')
|
28 |
+
process = subprocess.Popen(binary,
|
29 |
+
stdin=subprocess.PIPE,
|
30 |
+
stdout=subprocess.PIPE)
|
31 |
+
response, error = process.communicate(bytes(message, encoding="utf-8"))
|
32 |
+
except OSError as exc:
|
33 |
+
raise exc
|
34 |
+
return response
|
35 |
+
|
36 |
+
|
37 |
+
def backgroundHelmholtzEnumeration(tasks, g, timeout, _=None,
|
38 |
+
special=None, evaluationTimeout=None):
|
39 |
+
from pathos.multiprocessing import Pool
|
40 |
+
requests = list({t.request for t in tasks})
|
41 |
+
inputs = {r: list({tuplify(xs)
|
42 |
+
for t in tasks if t.request == r
|
43 |
+
for xs, y in t.examples})
|
44 |
+
for r in requests}
|
45 |
+
workers = Pool(len(requests))
|
46 |
+
promises = [workers.apply_async(helmholtzEnumeration,
|
47 |
+
args=(g, r, inputs[r], float(timeout)),
|
48 |
+
kwds={'special': special,
|
49 |
+
'evaluationTimeout': evaluationTimeout})
|
50 |
+
for r in requests]
|
51 |
+
|
52 |
+
def get():
|
53 |
+
results = [p.get() for p in promises]
|
54 |
+
frontiers = []
|
55 |
+
with timing("(Helmholtz enumeration) Decoded json into frontiers"):
|
56 |
+
for request, result in zip(requests, results):
|
57 |
+
response = json.loads(result.decode("utf-8"))
|
58 |
+
for b, entry in enumerate(response):
|
59 |
+
frontiers.append(Frontier([FrontierEntry(program=Program.parse(p),
|
60 |
+
logPrior=entry["ll"],
|
61 |
+
logLikelihood=0.)
|
62 |
+
for p in entry["programs"]],
|
63 |
+
task=Task(str(b),
|
64 |
+
request,
|
65 |
+
[])))
|
66 |
+
eprint("Total number of Helmholtz frontiers:", len(frontiers))
|
67 |
+
return frontiers
|
68 |
+
|
69 |
+
return get
|
70 |
+
|
71 |
+
|
72 |
+
if __name__ == "__main__":
|
73 |
+
from dreamcoder.recognition import RecognitionModel, DummyFeatureExtractor
|
74 |
+
g = Grammar.uniform([k1, k0, addition, subtraction, multiplication])
|
75 |
+
frontiers = helmholtzEnumeration(g,
|
76 |
+
arrow(tint, tint),
|
77 |
+
[[0], [1], [2]],
|
78 |
+
10.)
|
79 |
+
eprint("average frontier size", mean(len(f.entries) for f in frontiers))
|
80 |
+
f = DummyFeatureExtractor([])
|
81 |
+
r = RecognitionModel(f, g, hidden=[], contextual=True)
|
82 |
+
r.trainBiasOptimal(frontiers, frontiers, steps=70)
|
83 |
+
g = r.grammarOfTask(frontiers[0].task).untorch()
|
84 |
+
frontiers = helmholtzEnumeration(g,
|
85 |
+
arrow(tint, tint),
|
86 |
+
[[0], [1], [2]],
|
87 |
+
10.)
|
88 |
+
for f in frontiers:
|
89 |
+
eprint(f.summarizeFull())
|
90 |
+
eprint("average frontier size", mean(len(f.entries) for f in frontiers))
|
dreamcoder/ec.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
print("DEPRECATION NOTICE: this module (dreamcoder.ec) will be deleted soon, "
|
2 |
+
"please update your code to import from dreamcoder.dreamcoder instead")
|
3 |
+
from dreamcoder.dreamcoder import * # noqa
|
dreamcoder/enumeration.py
ADDED
@@ -0,0 +1,469 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dreamcoder.likelihoodModel import AllOrNothingLikelihoodModel
|
2 |
+
from dreamcoder.grammar import *
|
3 |
+
from dreamcoder.utilities import get_root_dir
|
4 |
+
|
5 |
+
import os
|
6 |
+
import traceback
|
7 |
+
import subprocess
|
8 |
+
|
9 |
+
|
10 |
+
def multicoreEnumeration(g, tasks, _=None,
|
11 |
+
enumerationTimeout=None,
|
12 |
+
solver='ocaml',
|
13 |
+
CPUs=1,
|
14 |
+
maximumFrontier=None,
|
15 |
+
verbose=True,
|
16 |
+
evaluationTimeout=None,
|
17 |
+
testing=False):
|
18 |
+
'''g: Either a Grammar, or a map from task to grammar.
|
19 |
+
Returns (list-of-frontiers, map-from-task-to-search-time)'''
|
20 |
+
|
21 |
+
# We don't use actual threads but instead use the multiprocessing
|
22 |
+
# library. This is because we need to be able to kill workers.
|
23 |
+
#from multiprocess import Process, Queue
|
24 |
+
|
25 |
+
from multiprocessing import Queue
|
26 |
+
|
27 |
+
# everything that gets sent between processes will be dilled
|
28 |
+
import dill
|
29 |
+
|
30 |
+
solvers = {"ocaml": solveForTask_ocaml,
|
31 |
+
"pypy": solveForTask_pypy,
|
32 |
+
"python": solveForTask_python}
|
33 |
+
assert solver in solvers, "You must specify a valid solver. options are ocaml, pypy, or python."
|
34 |
+
|
35 |
+
likelihoodModel = None
|
36 |
+
if solver == 'pypy' or solver == 'python':
|
37 |
+
# Use an all or nothing likelihood model.
|
38 |
+
likelihoodModel = AllOrNothingLikelihoodModel(timeout=evaluationTimeout)
|
39 |
+
|
40 |
+
solver = solvers[solver]
|
41 |
+
|
42 |
+
if not isinstance(g, dict):
|
43 |
+
g = {t: g for t in tasks}
|
44 |
+
task2grammar = g
|
45 |
+
|
46 |
+
# If we are not evaluating on held out testing tasks:
|
47 |
+
# Bin the tasks by request type and grammar
|
48 |
+
# If these are the same then we can enumerate for multiple tasks simultaneously
|
49 |
+
# If we are evaluating testing tasks:
|
50 |
+
# Make sure that each job corresponds to exactly one task
|
51 |
+
jobs = {}
|
52 |
+
for i, t in enumerate(tasks):
|
53 |
+
if testing:
|
54 |
+
k = (task2grammar[t], t.request, i)
|
55 |
+
else:
|
56 |
+
k = (task2grammar[t], t.request)
|
57 |
+
jobs[k] = jobs.get(k, []) + [t]
|
58 |
+
|
59 |
+
disableParallelism = len(jobs) == 1
|
60 |
+
parallelCallback = launchParallelProcess if not disableParallelism else lambda f, * \
|
61 |
+
a, **k: f(*a, **k)
|
62 |
+
if disableParallelism:
|
63 |
+
eprint("Disabling parallelism on the Python side because we only have one job.")
|
64 |
+
eprint("If you are using ocaml, there could still be parallelism.")
|
65 |
+
|
66 |
+
# Map from task to the shortest time to find a program solving it
|
67 |
+
bestSearchTime = {t: None for t in task2grammar}
|
68 |
+
|
69 |
+
lowerBounds = {k: 0. for k in jobs}
|
70 |
+
|
71 |
+
frontiers = {t: Frontier([], task=t) for t in task2grammar}
|
72 |
+
|
73 |
+
# For each job we keep track of how long we have been working on it
|
74 |
+
stopwatches = {t: Stopwatch() for t in jobs}
|
75 |
+
|
76 |
+
# Map from task to how many programs we enumerated for that task
|
77 |
+
taskToNumberOfPrograms = {t: 0 for t in tasks }
|
78 |
+
|
79 |
+
def numberOfHits(f):
|
80 |
+
return sum(e.logLikelihood > -0.01 for e in f)
|
81 |
+
|
82 |
+
def budgetIncrement(lb):
|
83 |
+
if True:
|
84 |
+
return 1.5
|
85 |
+
# Very heuristic - not sure what to do here
|
86 |
+
if lb < 24.:
|
87 |
+
return 1.
|
88 |
+
elif lb < 27.:
|
89 |
+
return 0.5
|
90 |
+
else:
|
91 |
+
return 0.25
|
92 |
+
|
93 |
+
def maximumFrontiers(j):
|
94 |
+
tasks = jobs[j]
|
95 |
+
return {t: maximumFrontier - numberOfHits(frontiers[t]) for t in tasks}
|
96 |
+
|
97 |
+
def allocateCPUs(n, tasks):
|
98 |
+
allocation = {t: 0 for t in tasks}
|
99 |
+
while n > 0:
|
100 |
+
for t in tasks:
|
101 |
+
# During testing we use exactly one CPU per task
|
102 |
+
if testing and allocation[t] > 0:
|
103 |
+
return allocation
|
104 |
+
allocation[t] += 1
|
105 |
+
n -= 1
|
106 |
+
if n == 0:
|
107 |
+
break
|
108 |
+
return allocation
|
109 |
+
|
110 |
+
def refreshJobs():
|
111 |
+
for k in list(jobs.keys()):
|
112 |
+
v = [t for t in jobs[k]
|
113 |
+
if numberOfHits(frontiers[t]) < maximumFrontier
|
114 |
+
and stopwatches[k].elapsed <= enumerationTimeout]
|
115 |
+
if v:
|
116 |
+
jobs[k] = v
|
117 |
+
else:
|
118 |
+
del jobs[k]
|
119 |
+
|
120 |
+
# Workers put their messages in here
|
121 |
+
q = Queue()
|
122 |
+
|
123 |
+
# How many CPUs are we using?
|
124 |
+
activeCPUs = 0
|
125 |
+
|
126 |
+
# How many CPUs was each job allocated?
|
127 |
+
id2CPUs = {}
|
128 |
+
# What job was each ID working on?
|
129 |
+
id2job = {}
|
130 |
+
nextID = 0
|
131 |
+
|
132 |
+
while True:
|
133 |
+
refreshJobs()
|
134 |
+
# Don't launch a job that we are already working on
|
135 |
+
# We run the stopwatch whenever the job is being worked on
|
136 |
+
# freeJobs are things that we are not working on but could be
|
137 |
+
freeJobs = [j for j in jobs if not stopwatches[j].running
|
138 |
+
and stopwatches[j].elapsed < enumerationTimeout - 0.5]
|
139 |
+
if freeJobs and activeCPUs < CPUs:
|
140 |
+
# Allocate a CPU to each of the jobs that we have made the least
|
141 |
+
# progress on
|
142 |
+
freeJobs.sort(key=lambda j: lowerBounds[j])
|
143 |
+
# Launch some more jobs until all of the CPUs are being used
|
144 |
+
availableCPUs = CPUs - activeCPUs
|
145 |
+
allocation = allocateCPUs(availableCPUs, freeJobs)
|
146 |
+
for j in freeJobs:
|
147 |
+
if allocation[j] == 0:
|
148 |
+
continue
|
149 |
+
g, request = j[:2]
|
150 |
+
bi = budgetIncrement(lowerBounds[j])
|
151 |
+
thisTimeout = enumerationTimeout - stopwatches[j].elapsed
|
152 |
+
eprint("(python) Launching %s (%d tasks) w/ %d CPUs. %f <= MDL < %f. Timeout %f." %
|
153 |
+
(request, len(jobs[j]), allocation[j], lowerBounds[j], lowerBounds[j] + bi, thisTimeout))
|
154 |
+
stopwatches[j].start()
|
155 |
+
parallelCallback(wrapInThread(solver),
|
156 |
+
q=q, g=g, ID=nextID,
|
157 |
+
elapsedTime=stopwatches[j].elapsed,
|
158 |
+
CPUs=allocation[j],
|
159 |
+
tasks=jobs[j],
|
160 |
+
lowerBound=lowerBounds[j],
|
161 |
+
upperBound=lowerBounds[j] + bi,
|
162 |
+
budgetIncrement=bi,
|
163 |
+
timeout=thisTimeout,
|
164 |
+
evaluationTimeout=evaluationTimeout,
|
165 |
+
maximumFrontiers=maximumFrontiers(j),
|
166 |
+
testing=testing,
|
167 |
+
likelihoodModel=likelihoodModel)
|
168 |
+
id2CPUs[nextID] = allocation[j]
|
169 |
+
id2job[nextID] = j
|
170 |
+
nextID += 1
|
171 |
+
|
172 |
+
activeCPUs += allocation[j]
|
173 |
+
lowerBounds[j] += bi
|
174 |
+
|
175 |
+
# If nothing is running, and we just tried to launch jobs,
|
176 |
+
# then that means we are finished
|
177 |
+
if all(not s.running for s in stopwatches.values()):
|
178 |
+
break
|
179 |
+
|
180 |
+
# Wait to get a response
|
181 |
+
message = Bunch(dill.loads(q.get()))
|
182 |
+
|
183 |
+
if message.result == "failure":
|
184 |
+
eprint("PANIC! Exception in child worker:", message.exception)
|
185 |
+
eprint(message.stacktrace)
|
186 |
+
assert False
|
187 |
+
elif message.result == "success":
|
188 |
+
# Mark the CPUs is no longer being used and pause the stopwatch
|
189 |
+
activeCPUs -= id2CPUs[message.ID]
|
190 |
+
stopwatches[id2job[message.ID]].stop()
|
191 |
+
|
192 |
+
newFrontiers, searchTimes, pc = message.value
|
193 |
+
for t, f in newFrontiers.items():
|
194 |
+
oldBest = None if len(
|
195 |
+
frontiers[t]) == 0 else frontiers[t].bestPosterior
|
196 |
+
frontiers[t] = frontiers[t].combine(f)
|
197 |
+
newBest = None if len(
|
198 |
+
frontiers[t]) == 0 else frontiers[t].bestPosterior
|
199 |
+
|
200 |
+
taskToNumberOfPrograms[t] += pc
|
201 |
+
|
202 |
+
dt = searchTimes[t]
|
203 |
+
if dt is not None:
|
204 |
+
if bestSearchTime[t] is None:
|
205 |
+
bestSearchTime[t] = dt
|
206 |
+
else:
|
207 |
+
# newBest & oldBest should both be defined
|
208 |
+
assert oldBest is not None
|
209 |
+
assert newBest is not None
|
210 |
+
newScore = newBest.logPrior + newBest.logLikelihood
|
211 |
+
oldScore = oldBest.logPrior + oldBest.logLikelihood
|
212 |
+
|
213 |
+
if newScore > oldScore:
|
214 |
+
bestSearchTime[t] = dt
|
215 |
+
elif newScore == oldScore:
|
216 |
+
bestSearchTime[t] = min(bestSearchTime[t], dt)
|
217 |
+
else:
|
218 |
+
eprint("Unknown message result:", message.result)
|
219 |
+
assert False
|
220 |
+
|
221 |
+
eprint("We enumerated this many programs, for each task:\n\t",
|
222 |
+
list(taskToNumberOfPrograms.values()))
|
223 |
+
|
224 |
+
return [frontiers[t] for t in tasks], bestSearchTime
|
225 |
+
|
226 |
+
def wrapInThread(f):
|
227 |
+
"""
|
228 |
+
Returns a function that is designed to be run in a thread/threadlike process.
|
229 |
+
Result will be either put into the q
|
230 |
+
"""
|
231 |
+
import dill
|
232 |
+
|
233 |
+
def _f(*a, **k):
|
234 |
+
q = k.pop("q")
|
235 |
+
ID = k.pop("ID")
|
236 |
+
|
237 |
+
try:
|
238 |
+
r = f(*a, **k)
|
239 |
+
q.put(dill.dumps({"result": "success",
|
240 |
+
"ID": ID,
|
241 |
+
"value": r}))
|
242 |
+
except Exception as e:
|
243 |
+
q.put(dill.dumps({"result": "failure",
|
244 |
+
"exception": e,
|
245 |
+
"stacktrace": traceback.format_exc(),
|
246 |
+
"ID": ID}))
|
247 |
+
return
|
248 |
+
return _f
|
249 |
+
|
250 |
+
|
251 |
+
def solveForTask_ocaml(_=None,
|
252 |
+
elapsedTime=0.,
|
253 |
+
CPUs=1,
|
254 |
+
g=None, tasks=None,
|
255 |
+
lowerBound=None, upperBound=None, budgetIncrement=None,
|
256 |
+
timeout=None,
|
257 |
+
testing=None, # FIXME: unused
|
258 |
+
likelihoodModel=None,
|
259 |
+
evaluationTimeout=None, maximumFrontiers=None):
|
260 |
+
|
261 |
+
import json
|
262 |
+
|
263 |
+
def taskMessage(t):
|
264 |
+
m = {
|
265 |
+
"examples": [{"inputs": list(xs), "output": y} for xs, y in t.examples],
|
266 |
+
"name": t.name,
|
267 |
+
"request": t.request.json(),
|
268 |
+
"maximumFrontier": maximumFrontiers[t]}
|
269 |
+
if hasattr(t, "specialTask"):
|
270 |
+
special, extra = t.specialTask
|
271 |
+
m["specialTask"] = special
|
272 |
+
m["extras"] = extra
|
273 |
+
return m
|
274 |
+
|
275 |
+
|
276 |
+
message = {"DSL": g.json(),
|
277 |
+
"tasks": [taskMessage(t)
|
278 |
+
for t in tasks],
|
279 |
+
|
280 |
+
"programTimeout": evaluationTimeout,
|
281 |
+
"nc": CPUs,
|
282 |
+
"timeout": timeout,
|
283 |
+
"lowerBound": lowerBound,
|
284 |
+
"upperBound": upperBound,
|
285 |
+
"budgetIncrement": budgetIncrement,
|
286 |
+
"verbose": False,
|
287 |
+
"shatter": 5 if len(tasks) == 1 and "turtle" in str(tasks[0].request) else 10}
|
288 |
+
|
289 |
+
if hasattr(tasks[0], 'maxParameters') and tasks[0].maxParameters is not None:
|
290 |
+
message["maxParameters"] = tasks[0].maxParameters
|
291 |
+
|
292 |
+
message = json.dumps(message)
|
293 |
+
# uncomment this if you want to save the messages being sent to the solver
|
294 |
+
|
295 |
+
|
296 |
+
try:
|
297 |
+
solver_file = os.path.join(get_root_dir(), 'solver')
|
298 |
+
process = subprocess.Popen(solver_file,
|
299 |
+
stdin=subprocess.PIPE,
|
300 |
+
stdout=subprocess.PIPE)
|
301 |
+
response, error = process.communicate(bytes(message, encoding="utf-8"))
|
302 |
+
response = json.loads(response.decode("utf-8"))
|
303 |
+
except OSError as exc:
|
304 |
+
raise exc
|
305 |
+
|
306 |
+
except:
|
307 |
+
print("response:", response)
|
308 |
+
print("error:", error)
|
309 |
+
with open("message", "w") as f:
|
310 |
+
f.write(message)
|
311 |
+
print("message,", message)
|
312 |
+
assert False, "MAX RAISE"
|
313 |
+
|
314 |
+
|
315 |
+
pc = response.get("number_enumerated",0) # TODO
|
316 |
+
frontiers = {}
|
317 |
+
searchTimes = {}
|
318 |
+
for t in tasks:
|
319 |
+
solutions = response[t.name]
|
320 |
+
frontier = Frontier([FrontierEntry(program=p,
|
321 |
+
logLikelihood=e["logLikelihood"],
|
322 |
+
logPrior=g.logLikelihood(t.request, p))
|
323 |
+
for e in solutions
|
324 |
+
for p in [Program.parse(e["program"])]],
|
325 |
+
task=t)
|
326 |
+
frontiers[t] = frontier
|
327 |
+
if frontier.empty:
|
328 |
+
searchTimes[t] = None
|
329 |
+
# This is subtle:
|
330 |
+
# The search time we report is actually not be minimum time to find any solution
|
331 |
+
# Rather it is the time to find the MAP solution
|
332 |
+
# This is important for regression problems,
|
333 |
+
# where we might find something with a good prior but bad likelihood early on,
|
334 |
+
# and only later discovered the good high likelihood program
|
335 |
+
else:
|
336 |
+
searchTimes[t] = min(
|
337 |
+
(e["logLikelihood"] + e["logPrior"],
|
338 |
+
e["time"]) for e in solutions)[1] + elapsedTime
|
339 |
+
|
340 |
+
return frontiers, searchTimes, pc
|
341 |
+
|
342 |
+
def solveForTask_pypy(_=None,
|
343 |
+
elapsedTime=0.,
|
344 |
+
g=None, task=None,
|
345 |
+
lowerBound=None, upperBound=None, budgetIncrement=None,
|
346 |
+
timeout=None,
|
347 |
+
likelihoodModel=None,
|
348 |
+
evaluationTimeout=None, maximumFrontier=None, testing=False):
|
349 |
+
return callCompiled(enumerateForTasks,
|
350 |
+
g, tasks, likelihoodModel,
|
351 |
+
timeout=timeout,
|
352 |
+
testing=testing,
|
353 |
+
elapsedTime=elapsedTime,
|
354 |
+
evaluationTimeout=evaluationTimeout,
|
355 |
+
maximumFrontiers=maximumFrontiers,
|
356 |
+
budgetIncrement=budgetIncrement,
|
357 |
+
lowerBound=lowerBound, upperBound=upperBound)
|
358 |
+
|
359 |
+
def solveForTask_python(_=None,
|
360 |
+
elapsedTime=0.,
|
361 |
+
g=None, tasks=None,
|
362 |
+
lowerBound=None, upperBound=None, budgetIncrement=None,
|
363 |
+
timeout=None,
|
364 |
+
CPUs=1,
|
365 |
+
likelihoodModel=None,
|
366 |
+
evaluationTimeout=None, maximumFrontiers=None, testing=False):
|
367 |
+
return enumerateForTasks(g, tasks, likelihoodModel,
|
368 |
+
timeout=timeout,
|
369 |
+
testing=testing,
|
370 |
+
elapsedTime=elapsedTime,
|
371 |
+
evaluationTimeout=evaluationTimeout,
|
372 |
+
maximumFrontiers=maximumFrontiers,
|
373 |
+
budgetIncrement=budgetIncrement,
|
374 |
+
lowerBound=lowerBound, upperBound=upperBound)
|
375 |
+
|
376 |
+
|
377 |
+
class EnumerationTimeout(Exception):
|
378 |
+
pass
|
379 |
+
|
380 |
+
def enumerateForTasks(g, tasks, likelihoodModel, _=None,
|
381 |
+
verbose=False,
|
382 |
+
timeout=None,
|
383 |
+
elapsedTime=0.,
|
384 |
+
CPUs=1,
|
385 |
+
testing=False, #unused
|
386 |
+
evaluationTimeout=None,
|
387 |
+
lowerBound=0.,
|
388 |
+
upperBound=100.,
|
389 |
+
budgetIncrement=1.0, maximumFrontiers=None):
|
390 |
+
assert timeout is not None, \
|
391 |
+
"enumerateForTasks: You must provide a timeout."
|
392 |
+
|
393 |
+
from time import time
|
394 |
+
|
395 |
+
request = tasks[0].request
|
396 |
+
assert all(t.request == request for t in tasks), \
|
397 |
+
"enumerateForTasks: Expected tasks to all have the same type"
|
398 |
+
|
399 |
+
maximumFrontiers = [maximumFrontiers[t] for t in tasks]
|
400 |
+
# store all of the hits in a priority queue
|
401 |
+
# we will never maintain maximumFrontier best solutions
|
402 |
+
hits = [PQ() for _ in tasks]
|
403 |
+
|
404 |
+
starting = time()
|
405 |
+
previousBudget = lowerBound
|
406 |
+
budget = lowerBound + budgetIncrement
|
407 |
+
try:
|
408 |
+
totalNumberOfPrograms = 0
|
409 |
+
while time() < starting + timeout and \
|
410 |
+
any(len(h) < mf for h, mf in zip(hits, maximumFrontiers)) and \
|
411 |
+
budget <= upperBound:
|
412 |
+
numberOfPrograms = 0
|
413 |
+
|
414 |
+
for prior, _, p in g.enumeration(Context.EMPTY, [], request,
|
415 |
+
maximumDepth=99,
|
416 |
+
upperBound=budget,
|
417 |
+
lowerBound=previousBudget):
|
418 |
+
descriptionLength = -prior
|
419 |
+
# Shouldn't see it on this iteration
|
420 |
+
assert descriptionLength <= budget
|
421 |
+
# Should already have seen it
|
422 |
+
assert descriptionLength > previousBudget
|
423 |
+
|
424 |
+
numberOfPrograms += 1
|
425 |
+
totalNumberOfPrograms += 1
|
426 |
+
|
427 |
+
for n in range(len(tasks)):
|
428 |
+
task = tasks[n]
|
429 |
+
|
430 |
+
#Warning:changed to max's new likelihood model situation
|
431 |
+
#likelihood = task.logLikelihood(p, evaluationTimeout)
|
432 |
+
#if invalid(likelihood):
|
433 |
+
#continue
|
434 |
+
success, likelihood = likelihoodModel.score(p, task)
|
435 |
+
if not success:
|
436 |
+
continue
|
437 |
+
|
438 |
+
dt = time() - starting + elapsedTime
|
439 |
+
priority = -(likelihood + prior)
|
440 |
+
hits[n].push(priority,
|
441 |
+
(dt, FrontierEntry(program=p,
|
442 |
+
logLikelihood=likelihood,
|
443 |
+
logPrior=prior)))
|
444 |
+
if len(hits[n]) > maximumFrontiers[n]:
|
445 |
+
hits[n].popMaximum()
|
446 |
+
|
447 |
+
if timeout is not None and time() - starting > timeout:
|
448 |
+
raise EnumerationTimeout
|
449 |
+
|
450 |
+
previousBudget = budget
|
451 |
+
budget += budgetIncrement
|
452 |
+
|
453 |
+
if budget > upperBound:
|
454 |
+
break
|
455 |
+
except EnumerationTimeout:
|
456 |
+
pass
|
457 |
+
frontiers = {tasks[n]: Frontier([e for _, e in hits[n]],
|
458 |
+
task=tasks[n])
|
459 |
+
for n in range(len(tasks))}
|
460 |
+
searchTimes = {
|
461 |
+
tasks[n]: None if len(hits[n]) == 0 else \
|
462 |
+
min(t for t,_ in hits[n]) for n in range(len(tasks))}
|
463 |
+
|
464 |
+
return frontiers, searchTimes, totalNumberOfPrograms
|
465 |
+
|
466 |
+
|
467 |
+
|
468 |
+
|
469 |
+
|
dreamcoder/fragmentGrammar.py
ADDED
@@ -0,0 +1,430 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dreamcoder.fragmentUtilities import *
|
2 |
+
from dreamcoder.grammar import *
|
3 |
+
from dreamcoder.program import *
|
4 |
+
|
5 |
+
from itertools import chain
|
6 |
+
import time
|
7 |
+
|
8 |
+
|
9 |
+
class FragmentGrammar(object):
|
10 |
+
def __init__(self, logVariable, productions):
|
11 |
+
self.logVariable = logVariable
|
12 |
+
self.productions = productions
|
13 |
+
self.likelihoodCache = {}
|
14 |
+
|
15 |
+
def clearCache(self):
|
16 |
+
self.likelihoodCache = {}
|
17 |
+
|
18 |
+
def __repr__(self):
|
19 |
+
return "FragmentGrammar(logVariable={self.logVariable}, productions={self.productions}".format(
|
20 |
+
self=self)
|
21 |
+
|
22 |
+
def __str__(self):
|
23 |
+
def productionKey(xxx_todo_changeme):
|
24 |
+
(l, t, p) = xxx_todo_changeme
|
25 |
+
return not isinstance(p, Primitive), -l
|
26 |
+
return "\n".join(["%f\tt0\t$_" % self.logVariable] + ["%f\t%s\t%s" % (l, t, p)
|
27 |
+
for l, t, p in sorted(self.productions, key=productionKey)])
|
28 |
+
|
29 |
+
def buildCandidates(self, context, environment, request):
|
30 |
+
candidates = []
|
31 |
+
variableCandidates = []
|
32 |
+
for l, t, p in self.productions:
|
33 |
+
try:
|
34 |
+
newContext, t = t.instantiate(context)
|
35 |
+
newContext = newContext.unify(t.returns(), request)
|
36 |
+
candidates.append((l, newContext,
|
37 |
+
t.apply(newContext),
|
38 |
+
p))
|
39 |
+
except UnificationFailure:
|
40 |
+
continue
|
41 |
+
for j, t in enumerate(environment):
|
42 |
+
try:
|
43 |
+
newContext = context.unify(t.returns(), request)
|
44 |
+
variableCandidates.append((newContext,
|
45 |
+
t.apply(newContext),
|
46 |
+
Index(j)))
|
47 |
+
except UnificationFailure:
|
48 |
+
continue
|
49 |
+
if variableCandidates:
|
50 |
+
z = math.log(len(variableCandidates))
|
51 |
+
for newContext, newType, index in variableCandidates:
|
52 |
+
candidates.append(
|
53 |
+
(self.logVariable - z, newContext, newType, index))
|
54 |
+
|
55 |
+
z = lse([candidate[0] for candidate in candidates])
|
56 |
+
return [(l - z, c, t, p) for l, c, t, p in candidates]
|
57 |
+
|
58 |
+
def logLikelihood(self, request, expression):
|
59 |
+
_, l, _ = self._logLikelihood(Context.EMPTY, [], request, expression)
|
60 |
+
if invalid(l):
|
61 |
+
f = 'failures/likelihoodFailure%s.pickle' % (time() + getPID())
|
62 |
+
eprint("PANIC: Invalid log likelihood. expression:",
|
63 |
+
expression, "tp:", request, "Exported to:", f)
|
64 |
+
with open(f, 'wb') as handle:
|
65 |
+
pickle.dump((self, request, expression), handle)
|
66 |
+
assert False
|
67 |
+
return l
|
68 |
+
|
69 |
+
def closedUses(self, request, expression):
|
70 |
+
_, l, u = self._logLikelihood(Context.EMPTY, [], request, expression)
|
71 |
+
return l, u
|
72 |
+
|
73 |
+
def _logLikelihood(self, context, environment, request, expression):
|
74 |
+
'''returns (context, log likelihood, uses)'''
|
75 |
+
|
76 |
+
# We can cash likelihood calculations faster whenever they don't involve type inference
|
77 |
+
# This is because they are guaranteed to not modify the context,
|
78 |
+
polymorphic = request.isPolymorphic or any(
|
79 |
+
v.isPolymorphic for v in environment)
|
80 |
+
# For some reason polymorphic caching slows it down
|
81 |
+
shouldDoCaching = not polymorphic
|
82 |
+
|
83 |
+
# Caching
|
84 |
+
if shouldDoCaching:
|
85 |
+
if polymorphic:
|
86 |
+
inTypes = canonicalTypes(
|
87 |
+
[request.apply(context)] + [v.apply(context) for v in environment])
|
88 |
+
else:
|
89 |
+
inTypes = canonicalTypes([request] + environment)
|
90 |
+
cacheKey = (tuple(inTypes), expression)
|
91 |
+
if cacheKey in self.likelihoodCache:
|
92 |
+
outTypes, l, u = self.likelihoodCache[cacheKey]
|
93 |
+
context, instantiatedTypes = instantiateTypes(
|
94 |
+
context, outTypes)
|
95 |
+
outRequest = instantiatedTypes[0]
|
96 |
+
outEnvironment = instantiatedTypes[1:]
|
97 |
+
# eprint("request:", request.apply(context), "environment:",
|
98 |
+
# [ v.apply(context) for v in environment ])
|
99 |
+
# eprint("will be unified with: out request:",outRequest,"out environment",outEnvironment)
|
100 |
+
if polymorphic:
|
101 |
+
context = context.unify(request, outRequest)
|
102 |
+
for v, vp in zip(environment, outEnvironment):
|
103 |
+
context = context.unify(v, vp)
|
104 |
+
return context, l, u
|
105 |
+
|
106 |
+
if request.isArrow():
|
107 |
+
if not isinstance(expression, Abstraction):
|
108 |
+
return (context, NEGATIVEINFINITY, Uses.empty)
|
109 |
+
return self._logLikelihood(context,
|
110 |
+
[request.arguments[0]] + environment,
|
111 |
+
request.arguments[1],
|
112 |
+
expression.body)
|
113 |
+
|
114 |
+
# Not a function type
|
115 |
+
|
116 |
+
# Construct and normalize the candidate productions
|
117 |
+
candidates = self.buildCandidates(context, environment, request)
|
118 |
+
|
119 |
+
# Consider each way of breaking the expression up into a
|
120 |
+
# function and arguments
|
121 |
+
totalLikelihood = NEGATIVEINFINITY
|
122 |
+
weightedUses = []
|
123 |
+
|
124 |
+
possibleVariables = float(int(any(isinstance(candidate, Index)
|
125 |
+
for _, _, _, candidate in candidates)))
|
126 |
+
possibleUses = {candidate: 1. for _, _, _, candidate in candidates
|
127 |
+
if not isinstance(candidate, Index)}
|
128 |
+
|
129 |
+
for f, xs in expression.applicationParses():
|
130 |
+
for candidateLikelihood, newContext, tp, production in candidates:
|
131 |
+
variableBindings = {}
|
132 |
+
# This is a variable in the environment
|
133 |
+
if production.isIndex:
|
134 |
+
if production != f:
|
135 |
+
continue
|
136 |
+
else:
|
137 |
+
try:
|
138 |
+
newContext, fragmentType, variableBindings = \
|
139 |
+
Matcher.match(newContext, production, f, len(xs))
|
140 |
+
# This is necessary because the types of the variable
|
141 |
+
# bindings and holes need to match up w/ request
|
142 |
+
fragmentTypeTemplate = request
|
143 |
+
for _ in xs:
|
144 |
+
newContext, newVariable = newContext.makeVariable()
|
145 |
+
fragmentTypeTemplate = arrow(
|
146 |
+
newVariable, fragmentTypeTemplate)
|
147 |
+
newContext = newContext.unify(
|
148 |
+
fragmentType, fragmentTypeTemplate)
|
149 |
+
# update the unified type
|
150 |
+
tp = fragmentType.apply(newContext)
|
151 |
+
except MatchFailure:
|
152 |
+
continue
|
153 |
+
|
154 |
+
argumentTypes = tp.functionArguments()
|
155 |
+
if len(xs) != len(argumentTypes):
|
156 |
+
# I think that this is some kind of bug. But I can't figure it out right now.
|
157 |
+
# As a hack, count this as though it were a failure
|
158 |
+
continue
|
159 |
+
#raise GrammarFailure('len(xs) != len(argumentTypes): tp={}, xs={}'.format(tp, xs))
|
160 |
+
|
161 |
+
thisLikelihood = candidateLikelihood
|
162 |
+
if isinstance(production, Index):
|
163 |
+
theseUses = Uses(possibleVariables=possibleVariables,
|
164 |
+
actualVariables=1.,
|
165 |
+
possibleUses=possibleUses.copy(),
|
166 |
+
actualUses={})
|
167 |
+
else:
|
168 |
+
theseUses = Uses(possibleVariables=possibleVariables,
|
169 |
+
actualVariables=0.,
|
170 |
+
possibleUses=possibleUses.copy(),
|
171 |
+
actualUses={production: 1.})
|
172 |
+
|
173 |
+
# Accumulate likelihood from free variables and holes and
|
174 |
+
# arguments
|
175 |
+
for freeType, freeExpression in chain(
|
176 |
+
variableBindings.values(), zip(argumentTypes, xs)):
|
177 |
+
freeType = freeType.apply(newContext)
|
178 |
+
newContext, expressionLikelihood, newUses = self._logLikelihood(
|
179 |
+
newContext, environment, freeType, freeExpression)
|
180 |
+
if expressionLikelihood is NEGATIVEINFINITY:
|
181 |
+
thisLikelihood = NEGATIVEINFINITY
|
182 |
+
break
|
183 |
+
|
184 |
+
thisLikelihood += expressionLikelihood
|
185 |
+
theseUses += newUses
|
186 |
+
|
187 |
+
if thisLikelihood is NEGATIVEINFINITY:
|
188 |
+
continue
|
189 |
+
|
190 |
+
weightedUses.append((thisLikelihood, theseUses))
|
191 |
+
totalLikelihood = lse(totalLikelihood, thisLikelihood)
|
192 |
+
|
193 |
+
# Any of these new context objects should be equally good
|
194 |
+
context = newContext
|
195 |
+
|
196 |
+
if totalLikelihood is NEGATIVEINFINITY:
|
197 |
+
return context, totalLikelihood, Uses.empty
|
198 |
+
assert weightedUses != []
|
199 |
+
|
200 |
+
allUses = Uses.join(totalLikelihood, *weightedUses)
|
201 |
+
|
202 |
+
# memoize result
|
203 |
+
if shouldDoCaching:
|
204 |
+
outTypes = [request.apply(context)] + \
|
205 |
+
[v.apply(context) for v in environment]
|
206 |
+
outTypes = canonicalTypes(outTypes)
|
207 |
+
self.likelihoodCache[cacheKey] = (
|
208 |
+
outTypes, totalLikelihood, allUses)
|
209 |
+
|
210 |
+
return context, totalLikelihood, allUses
|
211 |
+
|
212 |
+
def expectedUses(self, frontiers):
|
213 |
+
if len(list(frontiers)) == 0:
|
214 |
+
return Uses()
|
215 |
+
likelihoods = [[(l + entry.logLikelihood, u)
|
216 |
+
for entry in frontier
|
217 |
+
for l, u in [self.closedUses(frontier.task.request, entry.program)]]
|
218 |
+
for frontier in frontiers]
|
219 |
+
zs = (lse([l for l, _ in ls]) for ls in likelihoods)
|
220 |
+
return sum(math.exp(l - z) * u
|
221 |
+
for z, frontier in zip(zs, likelihoods)
|
222 |
+
for l, u in frontier)
|
223 |
+
|
224 |
+
def insideOutside(self, frontiers, pseudoCounts):
|
225 |
+
uses = self.expectedUses(frontiers)
|
226 |
+
return FragmentGrammar(log(uses.actualVariables +
|
227 |
+
pseudoCounts) -
|
228 |
+
log(max(uses.possibleVariables, 1.)), [(log(uses.actualUses.get(p, 0.) +
|
229 |
+
pseudoCounts) -
|
230 |
+
log(uses.possibleUses.get(p, 0.) +
|
231 |
+
pseudoCounts), t, p) for _, t, p in self.productions])
|
232 |
+
|
233 |
+
def jointFrontiersLikelihood(self, frontiers):
|
234 |
+
return sum(lse([entry.logLikelihood + self.logLikelihood(frontier.task.request, entry.program)
|
235 |
+
for entry in frontier])
|
236 |
+
for frontier in frontiers)
|
237 |
+
|
238 |
+
def jointFrontiersMDL(self, frontiers, CPUs=1):
|
239 |
+
return sum(
|
240 |
+
parallelMap(
|
241 |
+
CPUs,
|
242 |
+
lambda frontier: max(
|
243 |
+
entry.logLikelihood +
|
244 |
+
self.logLikelihood(
|
245 |
+
frontier.task.request,
|
246 |
+
entry.program) for entry in frontier),
|
247 |
+
frontiers))
|
248 |
+
|
249 |
+
def __len__(self): return len(self.productions)
|
250 |
+
|
251 |
+
@staticmethod
|
252 |
+
def fromGrammar(g):
|
253 |
+
return FragmentGrammar(g.logVariable, g.productions)
|
254 |
+
|
255 |
+
def toGrammar(self):
|
256 |
+
return Grammar(self.logVariable, [(l, q.infer(), q)
|
257 |
+
for l, t, p in self.productions
|
258 |
+
for q in [defragment(p)]])
|
259 |
+
|
260 |
+
@property
|
261 |
+
def primitives(self): return [p for _, _, p in self.productions]
|
262 |
+
|
263 |
+
@staticmethod
|
264 |
+
def uniform(productions):
|
265 |
+
return FragmentGrammar(0., [(0., p.infer(), p) for p in productions])
|
266 |
+
|
267 |
+
def normalize(self):
|
268 |
+
z = lse([l for l, t, p in self.productions] + [self.logVariable])
|
269 |
+
return FragmentGrammar(self.logVariable - z,
|
270 |
+
[(l - z, t, p) for l, t, p in self.productions])
|
271 |
+
|
272 |
+
def makeUniform(self):
|
273 |
+
return FragmentGrammar(0., [(0., p.infer(), p)
|
274 |
+
for _, _, p in self.productions])
|
275 |
+
|
276 |
+
def rescoreFrontier(self, frontier):
|
277 |
+
return Frontier([FrontierEntry(e.program,
|
278 |
+
logPrior=self.logLikelihood(frontier.task.request, e.program),
|
279 |
+
logLikelihood=e.logLikelihood)
|
280 |
+
for e in frontier],
|
281 |
+
frontier.task)
|
282 |
+
|
283 |
+
@staticmethod
|
284 |
+
def induceFromFrontiers(
|
285 |
+
g0,
|
286 |
+
frontiers,
|
287 |
+
_=None,
|
288 |
+
topK=1,
|
289 |
+
topk_use_only_likelihood=False,
|
290 |
+
pseudoCounts=1.0,
|
291 |
+
aic=1.0,
|
292 |
+
structurePenalty=0.001,
|
293 |
+
a=0,
|
294 |
+
CPUs=1):
|
295 |
+
_ = topk_use_only_likelihood # not used in python compressor
|
296 |
+
originalFrontiers = frontiers
|
297 |
+
frontiers = [frontier for frontier in frontiers if not frontier.empty]
|
298 |
+
eprint("Inducing a grammar from", len(frontiers), "frontiers")
|
299 |
+
|
300 |
+
bestGrammar = FragmentGrammar.fromGrammar(g0)
|
301 |
+
oldJoint = bestGrammar.jointFrontiersMDL(frontiers, CPUs=1)
|
302 |
+
|
303 |
+
# "restricted frontiers" only contain the top K according to the best grammar
|
304 |
+
def restrictFrontiers():
|
305 |
+
return parallelMap(
|
306 |
+
CPUs,
|
307 |
+
lambda f: bestGrammar.rescoreFrontier(f).topK(topK),
|
308 |
+
frontiers)
|
309 |
+
restrictedFrontiers = []
|
310 |
+
|
311 |
+
def grammarScore(g):
|
312 |
+
g = g.makeUniform().insideOutside(restrictedFrontiers, pseudoCounts)
|
313 |
+
likelihood = g.jointFrontiersMDL(restrictedFrontiers)
|
314 |
+
structure = sum(primitiveSize(p) for p in g.primitives)
|
315 |
+
score = likelihood - aic * len(g) - structurePenalty * structure
|
316 |
+
g.clearCache()
|
317 |
+
if invalid(score):
|
318 |
+
# FIXME: This should never occur but it does anyway
|
319 |
+
score = float('-inf')
|
320 |
+
return score, g
|
321 |
+
|
322 |
+
if aic is not POSITIVEINFINITY:
|
323 |
+
restrictedFrontiers = restrictFrontiers()
|
324 |
+
bestScore, _ = grammarScore(bestGrammar)
|
325 |
+
eprint("Starting score", bestScore)
|
326 |
+
while True:
|
327 |
+
restrictedFrontiers = restrictFrontiers()
|
328 |
+
fragments = [f
|
329 |
+
for f in proposeFragmentsFromFrontiers(restrictedFrontiers, a, CPUs=CPUs)
|
330 |
+
if f not in bestGrammar.primitives
|
331 |
+
and defragment(f) not in bestGrammar.primitives]
|
332 |
+
eprint("Proposed %d fragments." % len(fragments))
|
333 |
+
|
334 |
+
candidateGrammars = [
|
335 |
+
FragmentGrammar.uniform(
|
336 |
+
bestGrammar.primitives +
|
337 |
+
[fragment]) for fragment in fragments]
|
338 |
+
if not candidateGrammars:
|
339 |
+
break
|
340 |
+
|
341 |
+
scoredFragments = parallelMap(CPUs, grammarScore, candidateGrammars,
|
342 |
+
# Each process handles up to 100
|
343 |
+
# grammars at a time, a "job"
|
344 |
+
chunksize=max(
|
345 |
+
1, min(len(candidateGrammars) // CPUs, 100)),
|
346 |
+
# maxTasks: Maximum number of jobs allocated to a process
|
347 |
+
# This means that after evaluating this*chunk many grammars,
|
348 |
+
# we killed the process, freeing up its memory.
|
349 |
+
# In exchange we pay the cost of spawning a new process.
|
350 |
+
# We should play with this number,
|
351 |
+
# figuring out how big we can make it without
|
352 |
+
# running out of memory.
|
353 |
+
maxtasksperchild=5)
|
354 |
+
newScore, newGrammar = max(scoredFragments, key=lambda sg: sg[0])
|
355 |
+
|
356 |
+
if newScore <= bestScore:
|
357 |
+
break
|
358 |
+
dS = newScore - bestScore
|
359 |
+
bestScore, bestGrammar = newScore, newGrammar
|
360 |
+
newPrimitiveLikelihood, newType, newPrimitive = bestGrammar.productions[-1]
|
361 |
+
expectedUses = bestGrammar.expectedUses(
|
362 |
+
restrictedFrontiers).actualUses.get(newPrimitive, 0)
|
363 |
+
eprint(
|
364 |
+
"New primitive of type %s\t%s\t\n(score = %f; dScore = %f; <uses> = %f)" %
|
365 |
+
(newType, newPrimitive, newScore, dS, expectedUses))
|
366 |
+
|
367 |
+
# Rewrite the frontiers in terms of the new fragment
|
368 |
+
concretePrimitive = defragment(newPrimitive)
|
369 |
+
bestGrammar.productions[-1] = (newPrimitiveLikelihood,
|
370 |
+
concretePrimitive.tp,
|
371 |
+
concretePrimitive)
|
372 |
+
frontiers = parallelMap(
|
373 |
+
CPUs, lambda frontier: bestGrammar.rescoreFrontier(
|
374 |
+
RewriteFragments.rewriteFrontier(
|
375 |
+
frontier, newPrimitive)), frontiers)
|
376 |
+
eprint(
|
377 |
+
"\t(<uses> in rewritten frontiers: %f)" %
|
378 |
+
(bestGrammar.expectedUses(frontiers).actualUses[concretePrimitive]))
|
379 |
+
else:
|
380 |
+
eprint("Skipping fragment proposals")
|
381 |
+
|
382 |
+
if False:
|
383 |
+
# Reestimate the parameters using the entire frontiers
|
384 |
+
bestGrammar = bestGrammar.makeUniform().insideOutside(frontiers, pseudoCounts)
|
385 |
+
elif True:
|
386 |
+
# Reestimate the parameters using the best programs
|
387 |
+
restrictedFrontiers = restrictFrontiers()
|
388 |
+
bestGrammar = bestGrammar.makeUniform().insideOutside(
|
389 |
+
restrictedFrontiers, pseudoCounts)
|
390 |
+
else:
|
391 |
+
# Use parameters that were found during search
|
392 |
+
pass
|
393 |
+
|
394 |
+
eprint("Old joint = %f\tNew joint = %f\n" %
|
395 |
+
(oldJoint, bestGrammar.jointFrontiersMDL(frontiers, CPUs=CPUs)))
|
396 |
+
# Return all of the frontiers, which have now been rewritten to use the
|
397 |
+
# new fragments
|
398 |
+
frontiers = {f.task: f for f in frontiers}
|
399 |
+
frontiers = [frontiers.get(f.task, f)
|
400 |
+
for f in originalFrontiers]
|
401 |
+
|
402 |
+
productionUses = bestGrammar.expectedUses(
|
403 |
+
[f for f in frontiers if not f.empty]).actualUses
|
404 |
+
productionUses = {
|
405 |
+
p: productionUses.get(
|
406 |
+
p, 0.) for p in bestGrammar.primitives}
|
407 |
+
possibleUses = bestGrammar.expectedUses(
|
408 |
+
[f for f in frontiers if not f.empty]).possibleUses
|
409 |
+
possibleUses = {
|
410 |
+
p: possibleUses.get(
|
411 |
+
p, 0.) for p in bestGrammar.primitives}
|
412 |
+
|
413 |
+
for p in bestGrammar.primitives:
|
414 |
+
eprint("%f / %f\t%s" % (productionUses[p],
|
415 |
+
possibleUses[p],
|
416 |
+
p))
|
417 |
+
|
418 |
+
bestGrammar.clearCache()
|
419 |
+
|
420 |
+
grammar = bestGrammar.toGrammar()
|
421 |
+
|
422 |
+
if False and \
|
423 |
+
any(productionUses.get(p, 0) < 0.5 for p in grammar.primitives if p.isInvented):
|
424 |
+
uselessProductions = [ p for p in grammar.primitives
|
425 |
+
if p.isInvented and productionUses.get(p, 0) < 0.5]
|
426 |
+
eprint("The following invented primitives are no longer needed, removing them...")
|
427 |
+
eprint("\t" + "\t\n".join(map(str, uselessProductions)))
|
428 |
+
grammar = grammar.removeProductions(uselessProductions)
|
429 |
+
|
430 |
+
return grammar, frontiers
|
dreamcoder/fragmentUtilities.py
ADDED
@@ -0,0 +1,405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dreamcoder.type import *
|
2 |
+
from dreamcoder.program import *
|
3 |
+
from dreamcoder.frontier import *
|
4 |
+
|
5 |
+
from collections import Counter
|
6 |
+
|
7 |
+
|
8 |
+
class MatchFailure(Exception):
|
9 |
+
pass
|
10 |
+
|
11 |
+
|
12 |
+
class Matcher(object):
|
13 |
+
def __init__(self, context):
|
14 |
+
self.context = context
|
15 |
+
self.variableBindings = {}
|
16 |
+
|
17 |
+
@staticmethod
|
18 |
+
def match(context, fragment, expression, numberOfArguments):
|
19 |
+
if not mightMatch(fragment, expression):
|
20 |
+
raise MatchFailure()
|
21 |
+
m = Matcher(context)
|
22 |
+
tp = fragment.visit(m, expression, [], numberOfArguments)
|
23 |
+
return m.context, tp, m.variableBindings
|
24 |
+
|
25 |
+
def application(
|
26 |
+
self,
|
27 |
+
fragment,
|
28 |
+
expression,
|
29 |
+
environment,
|
30 |
+
numberOfArguments):
|
31 |
+
'''returns tp of fragment.'''
|
32 |
+
if not isinstance(expression, Application):
|
33 |
+
raise MatchFailure()
|
34 |
+
|
35 |
+
ft = fragment.f.visit(
|
36 |
+
self,
|
37 |
+
expression.f,
|
38 |
+
environment,
|
39 |
+
numberOfArguments + 1)
|
40 |
+
xt = fragment.x.visit(self, expression.x, environment, 0)
|
41 |
+
|
42 |
+
self.context, returnType = self.context.makeVariable()
|
43 |
+
try:
|
44 |
+
self.context = self.context.unify(ft, arrow(xt, returnType))
|
45 |
+
except UnificationFailure:
|
46 |
+
raise MatchFailure()
|
47 |
+
|
48 |
+
return returnType.apply(self.context)
|
49 |
+
|
50 |
+
def index(self, fragment, expression, environment, numberOfArguments):
|
51 |
+
# This is a bound variable
|
52 |
+
surroundingAbstractions = len(environment)
|
53 |
+
if fragment.bound(surroundingAbstractions):
|
54 |
+
if expression == fragment:
|
55 |
+
return environment[fragment.i].apply(self.context)
|
56 |
+
else:
|
57 |
+
raise MatchFailure()
|
58 |
+
|
59 |
+
# This is a free variable
|
60 |
+
i = fragment.i - surroundingAbstractions
|
61 |
+
|
62 |
+
# Make sure that it doesn't refer to anything bound by a
|
63 |
+
# lambda in the fragment. Otherwise it cannot be safely lifted
|
64 |
+
# out of the fragment and preserve semantics
|
65 |
+
for fv in expression.freeVariables():
|
66 |
+
if fv < len(environment):
|
67 |
+
raise MatchFailure()
|
68 |
+
|
69 |
+
# The value is going to be lifted out of the fragment
|
70 |
+
try:
|
71 |
+
expression = expression.shift(-surroundingAbstractions)
|
72 |
+
except ShiftFailure:
|
73 |
+
raise MatchFailure()
|
74 |
+
|
75 |
+
# Wrap it in the appropriate number of lambda expressions & applications
|
76 |
+
# This is because everything has to be in eta-longform
|
77 |
+
if numberOfArguments > 0:
|
78 |
+
expression = expression.shift(numberOfArguments)
|
79 |
+
for j in reversed(range(numberOfArguments)):
|
80 |
+
expression = Application(expression, Index(j))
|
81 |
+
for _ in range(numberOfArguments):
|
82 |
+
expression = Abstraction(expression)
|
83 |
+
|
84 |
+
# Added to the bindings
|
85 |
+
if i in self.variableBindings:
|
86 |
+
(tp, binding) = self.variableBindings[i]
|
87 |
+
if binding != expression:
|
88 |
+
raise MatchFailure()
|
89 |
+
else:
|
90 |
+
self.context, tp = self.context.makeVariable()
|
91 |
+
self.variableBindings[i] = (tp, expression)
|
92 |
+
return tp
|
93 |
+
|
94 |
+
def abstraction(
|
95 |
+
self,
|
96 |
+
fragment,
|
97 |
+
expression,
|
98 |
+
environment,
|
99 |
+
numberOfArguments):
|
100 |
+
if not isinstance(expression, Abstraction):
|
101 |
+
raise MatchFailure()
|
102 |
+
|
103 |
+
self.context, argumentType = self.context.makeVariable()
|
104 |
+
returnType = fragment.body.visit(
|
105 |
+
self, expression.body, [argumentType] + environment, 0)
|
106 |
+
|
107 |
+
return arrow(argumentType, returnType)
|
108 |
+
|
109 |
+
def primitive(self, fragment, expression, environment, numberOfArguments):
|
110 |
+
if fragment != expression:
|
111 |
+
raise MatchFailure()
|
112 |
+
self.context, tp = fragment.tp.instantiate(self.context)
|
113 |
+
return tp
|
114 |
+
|
115 |
+
def invented(self, fragment, expression, environment, numberOfArguments):
|
116 |
+
if fragment != expression:
|
117 |
+
raise MatchFailure()
|
118 |
+
self.context, tp = fragment.tp.instantiate(self.context)
|
119 |
+
return tp
|
120 |
+
|
121 |
+
def fragmentVariable(
|
122 |
+
self,
|
123 |
+
fragment,
|
124 |
+
expression,
|
125 |
+
environment,
|
126 |
+
numberOfArguments):
|
127 |
+
raise Exception(
|
128 |
+
'Deprecated: matching against fragment variables. Convert fragment to canonical form to get rid of fragment variables.')
|
129 |
+
|
130 |
+
|
131 |
+
def mightMatch(f, e, d=0):
|
132 |
+
'''Checks whether fragment f might be able to match against expression e'''
|
133 |
+
if f.isIndex:
|
134 |
+
if f.bound(d):
|
135 |
+
return f == e
|
136 |
+
return True
|
137 |
+
if f.isPrimitive or f.isInvented:
|
138 |
+
return f == e
|
139 |
+
if f.isAbstraction:
|
140 |
+
return e.isAbstraction and mightMatch(f.body, e.body, d + 1)
|
141 |
+
if f.isApplication:
|
142 |
+
return e.isApplication and mightMatch(
|
143 |
+
f.x, e.x, d) and mightMatch(
|
144 |
+
f.f, e.f, d)
|
145 |
+
assert False
|
146 |
+
|
147 |
+
|
148 |
+
def canonicalFragment(expression):
|
149 |
+
'''
|
150 |
+
Puts a fragment into a canonical form:
|
151 |
+
1. removes all FragmentVariable's
|
152 |
+
2. renames all free variables based on depth first traversal
|
153 |
+
'''
|
154 |
+
return expression.visit(CanonicalVisitor(), 0)
|
155 |
+
|
156 |
+
|
157 |
+
class CanonicalVisitor(object):
|
158 |
+
def __init__(self):
|
159 |
+
self.numberOfAbstractions = 0
|
160 |
+
self.mapping = {}
|
161 |
+
|
162 |
+
def fragmentVariable(self, e, d):
|
163 |
+
self.numberOfAbstractions += 1
|
164 |
+
return Index(self.numberOfAbstractions + d - 1)
|
165 |
+
|
166 |
+
def primitive(self, e, d): return e
|
167 |
+
|
168 |
+
def invented(self, e, d): return e
|
169 |
+
|
170 |
+
def application(self, e, d):
|
171 |
+
return Application(e.f.visit(self, d), e.x.visit(self, d))
|
172 |
+
|
173 |
+
def abstraction(self, e, d):
|
174 |
+
return Abstraction(e.body.visit(self, d + 1))
|
175 |
+
|
176 |
+
def index(self, e, d):
|
177 |
+
if e.bound(d):
|
178 |
+
return e
|
179 |
+
i = e.i - d
|
180 |
+
if i in self.mapping:
|
181 |
+
return Index(d + self.mapping[i])
|
182 |
+
self.mapping[i] = self.numberOfAbstractions
|
183 |
+
self.numberOfAbstractions += 1
|
184 |
+
return Index(self.numberOfAbstractions - 1 + d)
|
185 |
+
|
186 |
+
|
187 |
+
def fragmentSize(f, boundVariableCost=0.1, freeVariableCost=0.01):
|
188 |
+
freeVariables = 0
|
189 |
+
leaves = 0
|
190 |
+
boundVariables = 0
|
191 |
+
for surroundingAbstractions, e in f.walk():
|
192 |
+
if isinstance(e, (Primitive, Invented)):
|
193 |
+
leaves += 1
|
194 |
+
if isinstance(e, Index):
|
195 |
+
if e.bound(surroundingAbstractions):
|
196 |
+
boundVariables += 1
|
197 |
+
else:
|
198 |
+
freeVariables += 1
|
199 |
+
assert not isinstance(e, FragmentVariable)
|
200 |
+
return leaves + boundVariableCost * \
|
201 |
+
boundVariables + freeVariableCost * freeVariables
|
202 |
+
|
203 |
+
|
204 |
+
def primitiveSize(e):
|
205 |
+
if e.isInvented:
|
206 |
+
e = e.body
|
207 |
+
return fragmentSize(e)
|
208 |
+
|
209 |
+
|
210 |
+
def defragment(expression):
|
211 |
+
'''Converts a fragment into an invented primitive'''
|
212 |
+
if isinstance(expression, (Primitive, Invented)):
|
213 |
+
return expression
|
214 |
+
|
215 |
+
expression = canonicalFragment(expression)
|
216 |
+
|
217 |
+
for _ in range(expression.numberOfFreeVariables):
|
218 |
+
expression = Abstraction(expression)
|
219 |
+
|
220 |
+
return Invented(expression)
|
221 |
+
|
222 |
+
|
223 |
+
class RewriteFragments(object):
|
224 |
+
def __init__(self, fragment):
|
225 |
+
self.fragment = fragment
|
226 |
+
self.concrete = defragment(fragment)
|
227 |
+
|
228 |
+
def tryRewrite(self, e, numberOfArguments):
|
229 |
+
try:
|
230 |
+
context, t, bindings = Matcher.match(
|
231 |
+
Context.EMPTY, self.fragment, e, numberOfArguments)
|
232 |
+
except MatchFailure:
|
233 |
+
return None
|
234 |
+
|
235 |
+
assert frozenset(bindings.keys()) == frozenset(range(len(bindings))),\
|
236 |
+
"Perhaps the fragment is not in canonical form?"
|
237 |
+
e = self.concrete
|
238 |
+
for j in range(len(bindings) - 1, -1, -1):
|
239 |
+
_, b = bindings[j]
|
240 |
+
e = Application(e, b)
|
241 |
+
return e
|
242 |
+
|
243 |
+
def application(self, e, numberOfArguments):
|
244 |
+
e = Application(e.f.visit(self, numberOfArguments + 1),
|
245 |
+
e.x.visit(self, 0))
|
246 |
+
return self.tryRewrite(e, numberOfArguments) or e
|
247 |
+
|
248 |
+
def index(self, e, numberOfArguments): return e
|
249 |
+
|
250 |
+
def invented(self, e, numberOfArguments): return e
|
251 |
+
|
252 |
+
def primitive(self, e, numberOfArguments): return e
|
253 |
+
|
254 |
+
def abstraction(self, e, numberOfArguments):
|
255 |
+
e = Abstraction(e.body.visit(self, 0))
|
256 |
+
return self.tryRewrite(e, numberOfArguments) or e
|
257 |
+
|
258 |
+
def rewrite(self, e): return e.visit(self, 0)
|
259 |
+
|
260 |
+
@staticmethod
|
261 |
+
def rewriteFrontier(frontier, fragment):
|
262 |
+
worker = RewriteFragments(fragment)
|
263 |
+
return Frontier([FrontierEntry(program=worker.rewrite(e.program),
|
264 |
+
logLikelihood=e.logLikelihood,
|
265 |
+
logPrior=e.logPrior,
|
266 |
+
logPosterior=e.logPosterior)
|
267 |
+
for e in frontier],
|
268 |
+
task=frontier.task)
|
269 |
+
|
270 |
+
|
271 |
+
def proposeFragmentsFromFragment(f):
|
272 |
+
'''Abstracts out repeated structure within a single fragment'''
|
273 |
+
yield f
|
274 |
+
freeVariables = f.numberOfFreeVariables
|
275 |
+
closedSubtrees = Counter(
|
276 |
+
subtree for _,
|
277 |
+
subtree in f.walk() if not isinstance(
|
278 |
+
subtree,
|
279 |
+
Index) and subtree.closed)
|
280 |
+
del closedSubtrees[f]
|
281 |
+
for subtree, freq in closedSubtrees.items():
|
282 |
+
if freq < 2:
|
283 |
+
continue
|
284 |
+
yield canonicalFragment(f.substitute(subtree, Index(freeVariables)))
|
285 |
+
|
286 |
+
|
287 |
+
def nontrivial(f):
|
288 |
+
if not isinstance(f, Application):
|
289 |
+
return False
|
290 |
+
# Curry
|
291 |
+
if isinstance(f.x, FragmentVariable):
|
292 |
+
return False
|
293 |
+
if isinstance(f.x, Index):
|
294 |
+
# Make sure that the index is used somewhere else
|
295 |
+
if not any(
|
296 |
+
isinstance(
|
297 |
+
child,
|
298 |
+
Index) and child.i -
|
299 |
+
surroundingAbstractions == f.x.i for surroundingAbstractions,
|
300 |
+
child in f.f.walk()):
|
301 |
+
return False
|
302 |
+
|
303 |
+
numberOfHoles = 0
|
304 |
+
numberOfVariables = 0
|
305 |
+
numberOfPrimitives = 0
|
306 |
+
for surroundingAbstractions, child in f.walk():
|
307 |
+
if isinstance(child, (Primitive, Invented)):
|
308 |
+
numberOfPrimitives += 1
|
309 |
+
if isinstance(child, FragmentVariable):
|
310 |
+
numberOfHoles += 1
|
311 |
+
if isinstance(child, Index) and child.free(surroundingAbstractions):
|
312 |
+
numberOfVariables += 1
|
313 |
+
#eprint("Fragment %s has %d calls and %d variables and %d primitives"%(f,numberOfHoles,numberOfVariables,numberOfPrimitives))
|
314 |
+
|
315 |
+
return numberOfPrimitives + 0.5 * \
|
316 |
+
(numberOfHoles + numberOfVariables) > 1.5 and numberOfPrimitives >= 1
|
317 |
+
|
318 |
+
|
319 |
+
def violatesLaziness(fragment):
|
320 |
+
"""
|
321 |
+
conditionals are lazy on the second and third arguments. this
|
322 |
+
invariant must be maintained by learned fragments.
|
323 |
+
"""
|
324 |
+
for surroundingAbstractions, child in fragment.walkUncurried():
|
325 |
+
if not child.isApplication:
|
326 |
+
continue
|
327 |
+
f, xs = child.applicationParse()
|
328 |
+
if not (f.isPrimitive and f.name == "if"):
|
329 |
+
continue
|
330 |
+
|
331 |
+
# curried conditionals always violate laziness
|
332 |
+
if len(xs) != 3:
|
333 |
+
return True
|
334 |
+
|
335 |
+
# yes/no branches
|
336 |
+
y = xs[1]
|
337 |
+
n = xs[2]
|
338 |
+
|
339 |
+
return \
|
340 |
+
any(yc.isIndex and yc.i >= yd
|
341 |
+
for yd, yc in y.walk(surroundingAbstractions)) or \
|
342 |
+
any(nc.isIndex and nc.i >= nd
|
343 |
+
for nd, nc in n.walk(surroundingAbstractions))
|
344 |
+
|
345 |
+
return False
|
346 |
+
|
347 |
+
|
348 |
+
def proposeFragmentsFromProgram(p, arity):
|
349 |
+
|
350 |
+
def fragment(expression, a, toplevel=True):
|
351 |
+
"""Generates fragments that unify with expression"""
|
352 |
+
|
353 |
+
if a == 1:
|
354 |
+
yield FragmentVariable.single
|
355 |
+
if a == 0:
|
356 |
+
yield expression
|
357 |
+
return
|
358 |
+
|
359 |
+
if isinstance(expression, Abstraction):
|
360 |
+
# Symmetry breaking: (\x \y \z ... f(x,y,z,...)) defragments to be
|
361 |
+
# the same as f(x,y,z,...)
|
362 |
+
if not toplevel:
|
363 |
+
for b in fragment(expression.body, a, toplevel=False):
|
364 |
+
yield Abstraction(b)
|
365 |
+
elif isinstance(expression, Application):
|
366 |
+
for fa in range(a + 1):
|
367 |
+
for f in fragment(expression.f, fa, toplevel=False):
|
368 |
+
for x in fragment(expression.x, a - fa, toplevel=False):
|
369 |
+
yield Application(f, x)
|
370 |
+
else:
|
371 |
+
assert isinstance(expression, (Invented, Primitive, Index))
|
372 |
+
|
373 |
+
def fragments(expression, a):
|
374 |
+
"""Generates fragments that unify with subexpressions of expression"""
|
375 |
+
|
376 |
+
yield from fragment(expression, a)
|
377 |
+
if isinstance(expression, Application):
|
378 |
+
curry = True
|
379 |
+
if curry:
|
380 |
+
yield from fragments(expression.f, a)
|
381 |
+
yield from fragments(expression.x, a)
|
382 |
+
else:
|
383 |
+
# Pretend that it is not curried
|
384 |
+
function, arguments = expression.applicationParse()
|
385 |
+
yield from fragments(function, a)
|
386 |
+
for argument in arguments:
|
387 |
+
yield from fragments(argument, a)
|
388 |
+
elif isinstance(expression, Abstraction):
|
389 |
+
yield from fragments(expression.body, a)
|
390 |
+
else:
|
391 |
+
assert isinstance(expression, (Invented, Primitive, Index))
|
392 |
+
|
393 |
+
return {canonicalFragment(f) for b in range(arity + 1)
|
394 |
+
for f in fragments(p, b) if nontrivial(f)}
|
395 |
+
|
396 |
+
|
397 |
+
def proposeFragmentsFromFrontiers(frontiers, a, CPUs=1):
|
398 |
+
fragmentsFromEachFrontier = parallelMap(
|
399 |
+
CPUs, lambda frontier: {
|
400 |
+
fp for entry in frontier.entries for f in proposeFragmentsFromProgram(
|
401 |
+
entry.program, a) for fp in proposeFragmentsFromFragment(f)}, frontiers)
|
402 |
+
allFragments = Counter(f for frontierFragments in fragmentsFromEachFrontier
|
403 |
+
for f in frontierFragments)
|
404 |
+
return [fragment for fragment, frequency in allFragments.items()
|
405 |
+
if frequency >= 2 and fragment.wellTyped() and nontrivial(fragment)]
|
dreamcoder/frontier.py
ADDED
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dreamcoder.utilities import *
|
2 |
+
from dreamcoder.program import *
|
3 |
+
from dreamcoder.task import Task
|
4 |
+
|
5 |
+
|
6 |
+
class FrontierEntry(object):
|
7 |
+
def __init__(
|
8 |
+
self,
|
9 |
+
program,
|
10 |
+
_=None,
|
11 |
+
logPrior=None,
|
12 |
+
logLikelihood=None,
|
13 |
+
logPosterior=None):
|
14 |
+
self.logPosterior = logPrior + logLikelihood if logPosterior is None else logPosterior
|
15 |
+
self.program = program
|
16 |
+
self.logPrior = logPrior
|
17 |
+
self.logLikelihood = logLikelihood
|
18 |
+
|
19 |
+
def __repr__(self):
|
20 |
+
return "FrontierEntry(program={self.program}, logPrior={self.logPrior}, logLikelihood={self.logLikelihood}".format(
|
21 |
+
self=self)
|
22 |
+
|
23 |
+
def strip_primitive_values(self):
|
24 |
+
return FrontierEntry(program=strip_primitive_values(self.program),
|
25 |
+
logPrior=self.logPrior,
|
26 |
+
logPosterior=self.logPosterior,
|
27 |
+
logLikelihood=self.logLikelihood)
|
28 |
+
def unstrip_primitive_values(self):
|
29 |
+
return FrontierEntry(program=unstrip_primitive_values(self.program),
|
30 |
+
logPrior=self.logPrior,
|
31 |
+
logPosterior=self.logPosterior,
|
32 |
+
logLikelihood=self.logLikelihood)
|
33 |
+
|
34 |
+
|
35 |
+
class Frontier(object):
|
36 |
+
def __init__(self, frontier, task):
|
37 |
+
self.entries = frontier
|
38 |
+
self.task = task
|
39 |
+
|
40 |
+
def __repr__(
|
41 |
+
self): return "Frontier(entries={self.entries}, task={self.task})".format(self=self)
|
42 |
+
|
43 |
+
def __iter__(self): return iter(self.entries)
|
44 |
+
|
45 |
+
def __len__(self): return len(self.entries)
|
46 |
+
|
47 |
+
def json(self):
|
48 |
+
return {"request": self.task.request.json(),
|
49 |
+
"task": str(self.task),
|
50 |
+
"programs": [{"program": str(e.program),
|
51 |
+
"logLikelihood": e.logLikelihood}
|
52 |
+
for e in self ]}
|
53 |
+
|
54 |
+
def strip_primitive_values(self):
|
55 |
+
return Frontier([e.strip_primitive_values() for e in self.entries ],
|
56 |
+
self.task)
|
57 |
+
def unstrip_primitive_values(self):
|
58 |
+
return Frontier([e.unstrip_primitive_values() for e in self.entries ],
|
59 |
+
self.task)
|
60 |
+
|
61 |
+
DUMMYFRONTIERCOUNTER = 0
|
62 |
+
|
63 |
+
@staticmethod
|
64 |
+
def dummy(program, logLikelihood=0., logPrior=0., tp=None):
|
65 |
+
"""Creates a dummy frontier containing just this program"""
|
66 |
+
if not tp:
|
67 |
+
tp = program.infer().negateVariables()
|
68 |
+
|
69 |
+
t = Task(
|
70 |
+
"<dummy %d: %s>" %
|
71 |
+
(Frontier.DUMMYFRONTIERCOUNTER,
|
72 |
+
str(program)),
|
73 |
+
tp,
|
74 |
+
[])
|
75 |
+
f = Frontier([FrontierEntry(program=program,
|
76 |
+
logLikelihood=logLikelihood,
|
77 |
+
logPrior=logPrior)],
|
78 |
+
task=t)
|
79 |
+
Frontier.DUMMYFRONTIERCOUNTER += 1
|
80 |
+
return f
|
81 |
+
|
82 |
+
def marginalLikelihood(self):
|
83 |
+
return lse([e.logPrior + e.logLikelihood for e in self])
|
84 |
+
|
85 |
+
def temperature(self,T):
|
86 |
+
"""Divides prior by T"""
|
87 |
+
return Frontier([ FrontierEntry(program=e.program,
|
88 |
+
logPrior=e.logPrior/T,
|
89 |
+
logLikelihood=e.logLikelihood)
|
90 |
+
for e in self],
|
91 |
+
task=self.task)
|
92 |
+
|
93 |
+
|
94 |
+
def normalize(self):
|
95 |
+
z = self.marginalLikelihood()
|
96 |
+
newEntries = [
|
97 |
+
FrontierEntry(
|
98 |
+
program=e.program,
|
99 |
+
logPrior=e.logPrior,
|
100 |
+
logLikelihood=e.logLikelihood,
|
101 |
+
logPosterior=e.logPrior +
|
102 |
+
e.logLikelihood -
|
103 |
+
z) for e in self]
|
104 |
+
newEntries.sort(key=lambda e: e.logPosterior, reverse=True)
|
105 |
+
return Frontier(newEntries,
|
106 |
+
self.task)
|
107 |
+
|
108 |
+
def expectedProductionUses(self, g):
|
109 |
+
"""Returns a vector of the expected number of times each production was used"""
|
110 |
+
import numpy as np
|
111 |
+
|
112 |
+
this = g.rescoreFrontier(self).normalize()
|
113 |
+
ps = list(sorted(g.primitives, key=str))
|
114 |
+
features = np.zeros(len(ps))
|
115 |
+
|
116 |
+
for j, p in enumerate(ps):
|
117 |
+
for e in this:
|
118 |
+
w = math.exp(e.logPosterior)
|
119 |
+
features[j] += w * sum(child == p
|
120 |
+
for _, child in e.program.walk() )
|
121 |
+
if not p.isInvented: features[j] *= 0.3
|
122 |
+
return features
|
123 |
+
|
124 |
+
|
125 |
+
def removeZeroLikelihood(self):
|
126 |
+
self.entries = [
|
127 |
+
e for e in self.entries if e.logLikelihood != float('-inf')]
|
128 |
+
return self
|
129 |
+
|
130 |
+
def topK(self, k):
|
131 |
+
if k == 0: return Frontier([], self.task)
|
132 |
+
if k < 0: return self
|
133 |
+
newEntries = sorted(self.entries,
|
134 |
+
key=lambda e: (-e.logPosterior, str(e.program)))
|
135 |
+
return Frontier(newEntries[:k], self.task)
|
136 |
+
|
137 |
+
def sample(self):
|
138 |
+
"""Samples an entry from a frontier"""
|
139 |
+
return sampleLogDistribution([(e.logLikelihood + e.logPrior, e)
|
140 |
+
for e in self])
|
141 |
+
|
142 |
+
@property
|
143 |
+
def bestPosterior(self):
|
144 |
+
return min(self.entries,
|
145 |
+
key=lambda e: (-e.logPosterior, str(e.program)))
|
146 |
+
|
147 |
+
def replaceWithSupervised(self, g):
|
148 |
+
assert self.task.supervision is not None
|
149 |
+
return g.rescoreFrontier(Frontier([FrontierEntry(self.task.supervision,
|
150 |
+
logLikelihood=0., logPrior=0.)],
|
151 |
+
task=self.task))
|
152 |
+
|
153 |
+
@property
|
154 |
+
def bestll(self):
|
155 |
+
best = max(self.entries,
|
156 |
+
key=lambda e: e.logLikelihood)
|
157 |
+
return best.logLikelihood
|
158 |
+
|
159 |
+
|
160 |
+
@property
|
161 |
+
def empty(self): return self.entries == []
|
162 |
+
|
163 |
+
@staticmethod
|
164 |
+
def makeEmpty(task):
|
165 |
+
return Frontier([], task=task)
|
166 |
+
|
167 |
+
def summarize(self):
|
168 |
+
if self.empty:
|
169 |
+
return "MISS " + self.task.name
|
170 |
+
best = self.bestPosterior
|
171 |
+
return "HIT %s w/ %s ; log prior = %f ; log likelihood = %f" % (
|
172 |
+
self.task.name, best.program, best.logPrior, best.logLikelihood)
|
173 |
+
|
174 |
+
def summarizeFull(self):
|
175 |
+
if self.empty:
|
176 |
+
return "MISS " + self.task.name
|
177 |
+
return "\n".join([self.task.name] +
|
178 |
+
["%f\t%s" % (e.logPosterior, e.program)
|
179 |
+
for e in self.normalize()])
|
180 |
+
|
181 |
+
@staticmethod
|
182 |
+
def describe(frontiers):
|
183 |
+
numberOfHits = sum(not f.empty for f in frontiers)
|
184 |
+
if numberOfHits > 0:
|
185 |
+
averageLikelihood = sum(
|
186 |
+
f.bestPosterior.logPrior for f in frontiers if not f.empty) / numberOfHits
|
187 |
+
else:
|
188 |
+
averageLikelihood = 0
|
189 |
+
return "\n".join([f.summarize() for f in frontiers] +
|
190 |
+
["Hits %d/%d tasks" % (numberOfHits, len(frontiers))] +
|
191 |
+
["Average description length of a program solving a task: %f nats" % (-averageLikelihood)])
|
192 |
+
|
193 |
+
def combine(self, other, tolerance=0.01):
|
194 |
+
'''Takes the union of the programs in each of the frontiers'''
|
195 |
+
assert self.task == other.task
|
196 |
+
|
197 |
+
foundDifference = False
|
198 |
+
|
199 |
+
x = {e.program: e for e in self}
|
200 |
+
y = {e.program: e for e in other}
|
201 |
+
programs = set(x.keys()) | set(y.keys())
|
202 |
+
union = []
|
203 |
+
for p in programs:
|
204 |
+
if p in x:
|
205 |
+
e1 = x[p]
|
206 |
+
if p in y:
|
207 |
+
e2 = y[p]
|
208 |
+
if abs(e1.logPrior - e2.logPrior) > tolerance:
|
209 |
+
eprint(
|
210 |
+
"WARNING: Log priors differed during frontier combining: %f vs %f" %
|
211 |
+
(e1.logPrior, e2.logPrior))
|
212 |
+
eprint("WARNING: \tThe program is", p)
|
213 |
+
eprint()
|
214 |
+
if abs(e1.logLikelihood - e2.logLikelihood) > tolerance:
|
215 |
+
foundDifference = True
|
216 |
+
eprint(
|
217 |
+
"WARNING: Log likelihoods deferred for %s: %f & %f" %
|
218 |
+
(p, e1.logLikelihood, e2.logLikelihood))
|
219 |
+
if hasattr(self.task, 'BIC'):
|
220 |
+
eprint("\t%d examples, BIC=%f, parameterPenalty=%f, n parameters=%d, correct likelihood=%f" %
|
221 |
+
(len(self.task.examples),
|
222 |
+
self.task.BIC,
|
223 |
+
self.task.BIC * math.log(len(self.task.examples)),
|
224 |
+
substringOccurrences("REAL", str(p)),
|
225 |
+
substringOccurrences("REAL", str(p)) * self.task.BIC * math.log(len(self.task.examples))))
|
226 |
+
e1.logLikelihood = - \
|
227 |
+
substringOccurrences("REAL", str(p)) * self.task.BIC * math.log(len(self.task.examples))
|
228 |
+
e2.logLikelihood = e1.logLikelihood
|
229 |
+
|
230 |
+
e1 = FrontierEntry(
|
231 |
+
program=e1.program,
|
232 |
+
logLikelihood=(
|
233 |
+
e1.logLikelihood +
|
234 |
+
e2.logLikelihood) /
|
235 |
+
2,
|
236 |
+
logPrior=e1.logPrior)
|
237 |
+
else:
|
238 |
+
e1 = y[p]
|
239 |
+
union.append(e1)
|
240 |
+
|
241 |
+
if foundDifference:
|
242 |
+
eprint(
|
243 |
+
"WARNING: Log likelihoods differed for the same program on the task %s.\n" %
|
244 |
+
(self.task.name),
|
245 |
+
"\tThis is acceptable only if the likelihood model is stochastic. Took the geometric mean of the likelihoods.")
|
246 |
+
|
247 |
+
return Frontier(union, self.task)
|
dreamcoder/grammar.py
ADDED
@@ -0,0 +1,1308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import defaultdict
|
2 |
+
|
3 |
+
from dreamcoder.frontier import *
|
4 |
+
from dreamcoder.program import *
|
5 |
+
from dreamcoder.type import *
|
6 |
+
from dreamcoder.utilities import *
|
7 |
+
|
8 |
+
import time
|
9 |
+
|
10 |
+
class GrammarFailure(Exception):
|
11 |
+
pass
|
12 |
+
|
13 |
+
class SketchEnumerationFailure(Exception):
|
14 |
+
pass
|
15 |
+
|
16 |
+
class NoCandidates(Exception):
|
17 |
+
pass
|
18 |
+
|
19 |
+
|
20 |
+
class Grammar(object):
|
21 |
+
def __init__(self, logVariable, productions, continuationType=None):
|
22 |
+
self.logVariable = logVariable
|
23 |
+
self.productions = productions
|
24 |
+
|
25 |
+
self.continuationType = continuationType
|
26 |
+
|
27 |
+
self.expression2likelihood = dict((p, l) for l, _, p in productions)
|
28 |
+
self.expression2likelihood[Index(0)] = self.logVariable
|
29 |
+
|
30 |
+
def randomWeights(self, r):
|
31 |
+
"""returns a new grammar with random weights drawn from r. calls `r` w/ old weight"""
|
32 |
+
return Grammar(logVariable=r(self.logVariable),
|
33 |
+
productions=[(r(l),t,p)
|
34 |
+
for l,t,p in self.productions ],
|
35 |
+
continuationType=self.continuationType)
|
36 |
+
|
37 |
+
def strip_primitive_values(self):
|
38 |
+
return Grammar(logVariable=self.logVariable,
|
39 |
+
productions=[(l,t,strip_primitive_values(p))
|
40 |
+
for l,t,p in self.productions ],
|
41 |
+
continuationType=self.continuationType)
|
42 |
+
|
43 |
+
def unstrip_primitive_values(self):
|
44 |
+
return Grammar(logVariable=self.logVariable,
|
45 |
+
productions=[(l,t,unstrip_primitive_values(p))
|
46 |
+
for l,t,p in self.productions ],
|
47 |
+
continuationType=self.continuationType)
|
48 |
+
|
49 |
+
def __setstate__(self, state):
|
50 |
+
"""
|
51 |
+
Legacy support for loading grammar objects without the imperative type filled in
|
52 |
+
"""
|
53 |
+
assert 'logVariable' in state
|
54 |
+
assert 'productions' in state
|
55 |
+
if 'continuationType' in state:
|
56 |
+
continuationType = state['continuationType']
|
57 |
+
else:
|
58 |
+
if any( 'turtle' in str(t) for l,t,p in state['productions'] ):
|
59 |
+
continuationType = baseType("turtle")
|
60 |
+
elif any( 'tower' in str(t) for l,t,p in state['productions'] ):
|
61 |
+
continuationType = baseType("tower")
|
62 |
+
else:
|
63 |
+
continuationType = None
|
64 |
+
|
65 |
+
self.__init__(state['logVariable'], state['productions'], continuationType=continuationType)
|
66 |
+
|
67 |
+
@staticmethod
|
68 |
+
def fromProductions(productions, logVariable=0.0, continuationType=None):
|
69 |
+
"""Make a grammar from primitives and their relative logpriors."""
|
70 |
+
return Grammar(logVariable, [(l, p.infer(), p)
|
71 |
+
for l, p in productions],
|
72 |
+
continuationType=continuationType)
|
73 |
+
|
74 |
+
@staticmethod
|
75 |
+
def uniform(primitives, continuationType=None):
|
76 |
+
return Grammar(0.0, [(0.0, p.infer(), p) for p in primitives], continuationType=continuationType)
|
77 |
+
|
78 |
+
def __len__(self): return len(self.productions)
|
79 |
+
|
80 |
+
def __str__(self):
|
81 |
+
def productionKey(xxx_todo_changeme):
|
82 |
+
(l, t, p) = xxx_todo_changeme
|
83 |
+
return not isinstance(p, Primitive), l is not None and -l
|
84 |
+
if self.continuationType is not None:
|
85 |
+
lines = ["continuation : %s"%self.continuationType]
|
86 |
+
else:
|
87 |
+
lines = []
|
88 |
+
lines += ["%f\tt0\t$_" % self.logVariable]
|
89 |
+
for l, t, p in sorted(self.productions, key=productionKey):
|
90 |
+
if l is not None:
|
91 |
+
l = "%f\t%s\t%s" % (l, t, p)
|
92 |
+
else:
|
93 |
+
l = "-Inf\t%s\t%s" % (t, p)
|
94 |
+
if not t.isArrow() and isinstance(p, Invented):
|
95 |
+
try:
|
96 |
+
l += "\teval = %s" % (p.evaluate([]))
|
97 |
+
except BaseException:
|
98 |
+
pass
|
99 |
+
|
100 |
+
lines.append(l)
|
101 |
+
return "\n".join(lines)
|
102 |
+
|
103 |
+
def json(self):
|
104 |
+
j = {"logVariable": self.logVariable,
|
105 |
+
"productions": [{"expression": str(p), "logProbability": l}
|
106 |
+
for l, _, p in self.productions]}
|
107 |
+
if self.continuationType is not None:
|
108 |
+
j["continuationType"] = self.continuationType.json()
|
109 |
+
return j
|
110 |
+
|
111 |
+
def _immutable_code(self): return self.logVariable, tuple(self.productions)
|
112 |
+
|
113 |
+
def __eq__(self, o): return self._immutable_code() == o._immutable_code()
|
114 |
+
|
115 |
+
def __ne__(self, o): return not (self == o)
|
116 |
+
|
117 |
+
def __hash__(self): return hash(self._immutable_code())
|
118 |
+
|
119 |
+
@property
|
120 |
+
def primitives(self):
|
121 |
+
return [p for _, _, p in self.productions]
|
122 |
+
|
123 |
+
def removeProductions(self, ps):
|
124 |
+
return Grammar(
|
125 |
+
self.logVariable, [
|
126 |
+
(l, t, p) for (
|
127 |
+
l, t, p) in self.productions if p not in ps],
|
128 |
+
continuationType=self.continuationType)
|
129 |
+
|
130 |
+
def buildCandidates(self, request, context, environment,
|
131 |
+
# Should the log probabilities be normalized?
|
132 |
+
normalize=True,
|
133 |
+
# Should be returned a table mapping primitives to
|
134 |
+
# their candidate entry?
|
135 |
+
returnTable=False,
|
136 |
+
# Should we return probabilities vs log probabilities?
|
137 |
+
returnProbabilities=False,
|
138 |
+
# Must be a leaf (have no arguments)?
|
139 |
+
mustBeLeaf=False):
|
140 |
+
"""Primitives that are candidates for being used given a requested type
|
141 |
+
If returnTable is false (default): returns [((log)likelihood, tp, primitive, context)]
|
142 |
+
if returntable is true: returns {primitive: ((log)likelihood, tp, context)}"""
|
143 |
+
if returnProbabilities:
|
144 |
+
assert normalize
|
145 |
+
|
146 |
+
candidates = []
|
147 |
+
variableCandidates = []
|
148 |
+
for l, t, p in self.productions:
|
149 |
+
try:
|
150 |
+
newContext, t = t.instantiate(context)
|
151 |
+
newContext = newContext.unify(t.returns(), request)
|
152 |
+
t = t.apply(newContext)
|
153 |
+
if mustBeLeaf and t.isArrow():
|
154 |
+
continue
|
155 |
+
candidates.append((l, t, p, newContext))
|
156 |
+
except UnificationFailure:
|
157 |
+
continue
|
158 |
+
for j, t in enumerate(environment):
|
159 |
+
try:
|
160 |
+
newContext = context.unify(t.returns(), request)
|
161 |
+
t = t.apply(newContext)
|
162 |
+
if mustBeLeaf and t.isArrow():
|
163 |
+
continue
|
164 |
+
variableCandidates.append((t, Index(j), newContext))
|
165 |
+
except UnificationFailure:
|
166 |
+
continue
|
167 |
+
|
168 |
+
if self.continuationType == request:
|
169 |
+
terminalIndices = [v.i for t,v,k in variableCandidates if not t.isArrow()]
|
170 |
+
if terminalIndices:
|
171 |
+
smallestIndex = Index(min(terminalIndices))
|
172 |
+
variableCandidates = [(t,v,k) for t,v,k in variableCandidates
|
173 |
+
if t.isArrow() or v == smallestIndex]
|
174 |
+
|
175 |
+
candidates += [(self.logVariable - log(len(variableCandidates)), t, p, k)
|
176 |
+
for t, p, k in variableCandidates]
|
177 |
+
if candidates == []:
|
178 |
+
raise NoCandidates()
|
179 |
+
#eprint("candidates inside buildCandidates before norm:")
|
180 |
+
#eprint(candidates)
|
181 |
+
|
182 |
+
if normalize:
|
183 |
+
z = lse([l for l, t, p, k in candidates])
|
184 |
+
if returnProbabilities:
|
185 |
+
candidates = [(exp(l - z), t, p, k)
|
186 |
+
for l, t, p, k in candidates]
|
187 |
+
else:
|
188 |
+
candidates = [(l - z, t, p, k) for l, t, p, k in candidates]
|
189 |
+
|
190 |
+
#eprint("candidates inside buildCandidates after norm:")
|
191 |
+
#eprint(candidates)
|
192 |
+
|
193 |
+
if returnTable:
|
194 |
+
return {p: (l, t, k) for l, t, p, k in candidates}
|
195 |
+
else:
|
196 |
+
return candidates
|
197 |
+
|
198 |
+
|
199 |
+
def sample(self, request, maximumDepth=6, maxAttempts=None):
|
200 |
+
attempts = 0
|
201 |
+
|
202 |
+
while True:
|
203 |
+
try:
|
204 |
+
_, e = self._sample(
|
205 |
+
request, Context.EMPTY, [], maximumDepth=maximumDepth)
|
206 |
+
return e
|
207 |
+
except NoCandidates:
|
208 |
+
if maxAttempts is not None:
|
209 |
+
attempts += 1
|
210 |
+
if attempts > maxAttempts:
|
211 |
+
return None
|
212 |
+
continue
|
213 |
+
|
214 |
+
def _sample(self, request, context, environment, maximumDepth):
|
215 |
+
if request.isArrow():
|
216 |
+
context, expression = self._sample(
|
217 |
+
request.arguments[1], context, [
|
218 |
+
request.arguments[0]] + environment, maximumDepth)
|
219 |
+
return context, Abstraction(expression)
|
220 |
+
|
221 |
+
candidates = self.buildCandidates(request, context, environment,
|
222 |
+
normalize=True,
|
223 |
+
returnProbabilities=True,
|
224 |
+
# Force it to terminate in a
|
225 |
+
# leaf; a primitive with no
|
226 |
+
# function arguments
|
227 |
+
mustBeLeaf=maximumDepth <= 1)
|
228 |
+
#eprint("candidates:")
|
229 |
+
#eprint(candidates)
|
230 |
+
newType, chosenPrimitive, context = sampleDistribution(candidates)
|
231 |
+
|
232 |
+
# Sample the arguments
|
233 |
+
xs = newType.functionArguments()
|
234 |
+
returnValue = chosenPrimitive
|
235 |
+
|
236 |
+
for x in xs:
|
237 |
+
x = x.apply(context)
|
238 |
+
context, x = self._sample(x, context, environment, maximumDepth - 1)
|
239 |
+
returnValue = Application(returnValue, x)
|
240 |
+
|
241 |
+
return context, returnValue
|
242 |
+
|
243 |
+
def likelihoodSummary(self, context, environment, request, expression, silent=False):
|
244 |
+
if request.isArrow():
|
245 |
+
if not isinstance(expression, Abstraction):
|
246 |
+
if not silent:
|
247 |
+
eprint("Request is an arrow but I got", expression)
|
248 |
+
return context, None
|
249 |
+
return self.likelihoodSummary(context,
|
250 |
+
[request.arguments[0]] + environment,
|
251 |
+
request.arguments[1],
|
252 |
+
expression.body,
|
253 |
+
silent=silent)
|
254 |
+
# Build the candidates
|
255 |
+
candidates = self.buildCandidates(request, context, environment,
|
256 |
+
normalize=False,
|
257 |
+
returnTable=True)
|
258 |
+
|
259 |
+
# A list of everything that would have been possible to use here
|
260 |
+
possibles = [p for p in candidates.keys() if not p.isIndex]
|
261 |
+
numberOfVariables = sum(p.isIndex for p in candidates.keys())
|
262 |
+
if numberOfVariables > 0:
|
263 |
+
possibles += [Index(0)]
|
264 |
+
|
265 |
+
f, xs = expression.applicationParse()
|
266 |
+
|
267 |
+
if f not in candidates:
|
268 |
+
if self.continuationType is not None and f.isIndex:
|
269 |
+
ls = LikelihoodSummary()
|
270 |
+
ls.constant = NEGATIVEINFINITY
|
271 |
+
return ls
|
272 |
+
|
273 |
+
if not silent:
|
274 |
+
eprint(f, "Not in candidates")
|
275 |
+
eprint("Candidates is", candidates)
|
276 |
+
#eprint("grammar:", grammar.productions)
|
277 |
+
eprint("request is", request)
|
278 |
+
eprint("xs", xs)
|
279 |
+
eprint("environment", environment)
|
280 |
+
assert False
|
281 |
+
return context, None
|
282 |
+
|
283 |
+
thisSummary = LikelihoodSummary()
|
284 |
+
thisSummary.record(f, possibles,
|
285 |
+
constant= -math.log(numberOfVariables) if f.isIndex else 0)
|
286 |
+
|
287 |
+
_, tp, context = candidates[f]
|
288 |
+
argumentTypes = tp.functionArguments()
|
289 |
+
if len(xs) != len(argumentTypes):
|
290 |
+
eprint("PANIC: not enough arguments for the type")
|
291 |
+
eprint("request", request)
|
292 |
+
eprint("tp", tp)
|
293 |
+
eprint("expression", expression)
|
294 |
+
eprint("xs", xs)
|
295 |
+
eprint("argumentTypes", argumentTypes)
|
296 |
+
# This should absolutely never occur
|
297 |
+
raise GrammarFailure((context, environment, request, expression))
|
298 |
+
|
299 |
+
for argumentType, argument in zip(argumentTypes, xs):
|
300 |
+
argumentType = argumentType.apply(context)
|
301 |
+
context, newSummary = self.likelihoodSummary(
|
302 |
+
context, environment, argumentType, argument, silent=silent)
|
303 |
+
if newSummary is None:
|
304 |
+
return context, None
|
305 |
+
thisSummary.join(newSummary)
|
306 |
+
|
307 |
+
return context, thisSummary
|
308 |
+
|
309 |
+
def bestFirstEnumeration(self, request):
|
310 |
+
from heapq import heappush, heappop
|
311 |
+
|
312 |
+
pq = []
|
313 |
+
|
314 |
+
def choices(parentCost, xs):
|
315 |
+
for c, x in xs:
|
316 |
+
heappush(pq, (parentCost + c, x))
|
317 |
+
|
318 |
+
def g(parentCost, request, _=None,
|
319 |
+
context=None, environment=[],
|
320 |
+
k=None):
|
321 |
+
"""
|
322 |
+
k is a continuation.
|
323 |
+
k: Expects to be called with MDL, context, expression.
|
324 |
+
"""
|
325 |
+
|
326 |
+
assert k is not None
|
327 |
+
if context is None:
|
328 |
+
context = Context.EMPTY
|
329 |
+
|
330 |
+
if request.isArrow():
|
331 |
+
g(parentCost,
|
332 |
+
request.arguments[1],
|
333 |
+
context=context,
|
334 |
+
environment=[request.arguments[0]] + environment,
|
335 |
+
k=lambda MDL,
|
336 |
+
newContext,
|
337 |
+
p: k(MDL,
|
338 |
+
newContext,
|
339 |
+
Abstraction(p)))
|
340 |
+
else:
|
341 |
+
candidates = self.buildCandidates(request,
|
342 |
+
context,
|
343 |
+
environment,
|
344 |
+
normalize=True,
|
345 |
+
returnProbabilities=False,
|
346 |
+
returnTable=True)
|
347 |
+
choices(parentCost,
|
348 |
+
[(-f_ll_tp_newContext[1][0],
|
349 |
+
lambda: ga(parentCost - f_ll_tp_newContext[1][0],
|
350 |
+
f_ll_tp_newContext[0],
|
351 |
+
f_ll_tp_newContext[1][1].functionArguments(),
|
352 |
+
context=f_ll_tp_newContext[1][2],
|
353 |
+
environment=environment,
|
354 |
+
k=k)) for f_ll_tp_newContext in iter(candidates.items())])
|
355 |
+
|
356 |
+
def ga(costSoFar, f, argumentTypes, _=None,
|
357 |
+
context=None, environment=None,
|
358 |
+
k=None):
|
359 |
+
if argumentTypes == []:
|
360 |
+
k(costSoFar, context, f)
|
361 |
+
else:
|
362 |
+
t1 = argumentTypes[0].apply(context)
|
363 |
+
g(costSoFar, t1, context=context, environment=environment,
|
364 |
+
k=lambda newCost, newContext, argument:
|
365 |
+
ga(newCost, Application(f, argument), argumentTypes[1:],
|
366 |
+
context=newContext, environment=environment,
|
367 |
+
k=k))
|
368 |
+
|
369 |
+
def receiveResult(MDL, _, expression):
|
370 |
+
heappush(pq, (MDL, expression))
|
371 |
+
|
372 |
+
g(0., request, context=Context.EMPTY, environment=[], k=receiveResult)
|
373 |
+
frontier = []
|
374 |
+
while len(frontier) < 10**3:
|
375 |
+
MDL, action = heappop(pq)
|
376 |
+
if isinstance(action, Program):
|
377 |
+
expression = action
|
378 |
+
frontier.append(expression)
|
379 |
+
#eprint("Enumerated program",expression,-MDL,self.closedLogLikelihood(request, expression))
|
380 |
+
else:
|
381 |
+
action()
|
382 |
+
|
383 |
+
def closedLikelihoodSummary(self, request, expression, silent=False):
|
384 |
+
try:
|
385 |
+
context, summary = self.likelihoodSummary(Context.EMPTY, [], request, expression, silent=silent)
|
386 |
+
except GrammarFailure as e:
|
387 |
+
failureExport = 'failures/grammarFailure%s.pickle' % (
|
388 |
+
time.time() + getPID())
|
389 |
+
eprint("PANIC: Grammar failure, exporting to ", failureExport)
|
390 |
+
with open(failureExport, 'wb') as handle:
|
391 |
+
pickle.dump((e, self, request, expression), handle)
|
392 |
+
assert False
|
393 |
+
|
394 |
+
return summary
|
395 |
+
|
396 |
+
def logLikelihood(self, request, expression):
|
397 |
+
summary = self.closedLikelihoodSummary(request, expression)
|
398 |
+
if summary is None:
|
399 |
+
eprint(
|
400 |
+
"FATAL: program [ %s ] does not have a likelihood summary." %
|
401 |
+
expression, "r = ", request, "\n", self)
|
402 |
+
assert False
|
403 |
+
return summary.logLikelihood(self)
|
404 |
+
|
405 |
+
def rescoreFrontier(self, frontier):
|
406 |
+
return Frontier([FrontierEntry(e.program,
|
407 |
+
logPrior=self.logLikelihood(frontier.task.request, e.program),
|
408 |
+
logLikelihood=e.logLikelihood)
|
409 |
+
for e in frontier],
|
410 |
+
frontier.task)
|
411 |
+
|
412 |
+
def productionUses(self, frontiers):
|
413 |
+
"""Returns the expected number of times that each production was used. {production: expectedUses}"""
|
414 |
+
frontiers = [self.rescoreFrontier(f).normalize()
|
415 |
+
for f in frontiers if not f.empty]
|
416 |
+
uses = {p: 0. for p in self.primitives}
|
417 |
+
for f in frontiers:
|
418 |
+
for e in f:
|
419 |
+
summary = self.closedLikelihoodSummary(f.task.request,
|
420 |
+
e.program)
|
421 |
+
for p, u in summary.uses:
|
422 |
+
uses[p] += u * math.exp(e.logPosterior)
|
423 |
+
return uses
|
424 |
+
|
425 |
+
def insideOutside(self, frontiers, pseudoCounts, iterations=1):
|
426 |
+
# Replace programs with (likelihood summary, uses)
|
427 |
+
frontiers = [ Frontier([ FrontierEntry((summary, summary.toUses()),
|
428 |
+
logPrior=summary.logLikelihood(self),
|
429 |
+
logLikelihood=e.logLikelihood)
|
430 |
+
for e in f
|
431 |
+
for summary in [self.closedLikelihoodSummary(f.task.request, e.program)] ],
|
432 |
+
task=f.task)
|
433 |
+
for f in frontiers ]
|
434 |
+
|
435 |
+
g = self
|
436 |
+
for i in range(iterations):
|
437 |
+
u = Uses()
|
438 |
+
for f in frontiers:
|
439 |
+
f = f.normalize()
|
440 |
+
for e in f:
|
441 |
+
_, eu = e.program
|
442 |
+
u += math.exp(e.logPosterior) * eu
|
443 |
+
|
444 |
+
lv = math.log(u.actualVariables + pseudoCounts) - \
|
445 |
+
math.log(u.possibleVariables + pseudoCounts)
|
446 |
+
g = Grammar(lv,
|
447 |
+
[ (math.log(u.actualUses.get(p,0.) + pseudoCounts) - \
|
448 |
+
math.log(u.possibleUses.get(p,0.) + pseudoCounts),
|
449 |
+
t,p)
|
450 |
+
for _,t,p in g.productions ],
|
451 |
+
continuationType=self.continuationType)
|
452 |
+
if i < iterations - 1:
|
453 |
+
frontiers = [Frontier([ FrontierEntry((summary, uses),
|
454 |
+
logPrior=summary.logLikelihood(g),
|
455 |
+
logLikelihood=e.logLikelihood)
|
456 |
+
for e in f
|
457 |
+
for (summary, uses) in [e.program] ],
|
458 |
+
task=f.task)
|
459 |
+
for f in frontiers ]
|
460 |
+
return g
|
461 |
+
|
462 |
+
def frontierMDL(self, frontier):
|
463 |
+
return max( e.logLikelihood + self.logLikelihood(frontier.task.request, e.program)
|
464 |
+
for e in frontier )
|
465 |
+
|
466 |
+
|
467 |
+
def enumeration(self,context,environment,request,upperBound,
|
468 |
+
maximumDepth=20,
|
469 |
+
lowerBound=0.):
|
470 |
+
'''Enumerates all programs whose MDL satisfies: lowerBound <= MDL < upperBound'''
|
471 |
+
if upperBound < 0 or maximumDepth == 1:
|
472 |
+
return
|
473 |
+
|
474 |
+
if request.isArrow():
|
475 |
+
v = request.arguments[0]
|
476 |
+
for l, newContext, b in self.enumeration(context, [v] + environment,
|
477 |
+
request.arguments[1],
|
478 |
+
upperBound=upperBound,
|
479 |
+
lowerBound=lowerBound,
|
480 |
+
maximumDepth=maximumDepth):
|
481 |
+
yield l, newContext, Abstraction(b)
|
482 |
+
|
483 |
+
else:
|
484 |
+
candidates = self.buildCandidates(request, context, environment,
|
485 |
+
normalize=True)
|
486 |
+
|
487 |
+
for l, t, p, newContext in candidates:
|
488 |
+
mdl = -l
|
489 |
+
if not (mdl < upperBound):
|
490 |
+
continue
|
491 |
+
|
492 |
+
xs = t.functionArguments()
|
493 |
+
for aL, aK, application in\
|
494 |
+
self.enumerateApplication(newContext, environment, p, xs,
|
495 |
+
upperBound=upperBound + l,
|
496 |
+
lowerBound=lowerBound + l,
|
497 |
+
maximumDepth=maximumDepth - 1):
|
498 |
+
yield aL + l, aK, application
|
499 |
+
|
500 |
+
def enumerateApplication(self, context, environment,
|
501 |
+
function, argumentRequests,
|
502 |
+
# Upper bound on the description length of all of
|
503 |
+
# the arguments
|
504 |
+
upperBound,
|
505 |
+
# Lower bound on the description length of all of
|
506 |
+
# the arguments
|
507 |
+
lowerBound=0.,
|
508 |
+
maximumDepth=20,
|
509 |
+
originalFunction=None,
|
510 |
+
argumentIndex=0):
|
511 |
+
if upperBound < 0. or maximumDepth == 1:
|
512 |
+
return
|
513 |
+
if originalFunction is None:
|
514 |
+
originalFunction = function
|
515 |
+
|
516 |
+
if argumentRequests == []:
|
517 |
+
if lowerBound <= 0. and 0. < upperBound:
|
518 |
+
yield 0., context, function
|
519 |
+
else:
|
520 |
+
return
|
521 |
+
else:
|
522 |
+
argRequest = argumentRequests[0].apply(context)
|
523 |
+
laterRequests = argumentRequests[1:]
|
524 |
+
for argL, newContext, arg in self.enumeration(context, environment, argRequest,
|
525 |
+
upperBound=upperBound,
|
526 |
+
lowerBound=0.,
|
527 |
+
maximumDepth=maximumDepth):
|
528 |
+
if violatesSymmetry(originalFunction, arg, argumentIndex):
|
529 |
+
continue
|
530 |
+
|
531 |
+
newFunction = Application(function, arg)
|
532 |
+
for resultL, resultK, result in self.enumerateApplication(newContext, environment, newFunction,
|
533 |
+
laterRequests,
|
534 |
+
upperBound=upperBound + argL,
|
535 |
+
lowerBound=lowerBound + argL,
|
536 |
+
maximumDepth=maximumDepth,
|
537 |
+
originalFunction=originalFunction,
|
538 |
+
argumentIndex=argumentIndex + 1):
|
539 |
+
yield resultL + argL, resultK, result
|
540 |
+
|
541 |
+
def sketchEnumeration(self,context,environment,request,sk,upperBound,
|
542 |
+
maximumDepth=20,
|
543 |
+
lowerBound=0.):
|
544 |
+
'''Enumerates all sketch instantiations whose MDL satisfies: lowerBound <= MDL < upperBound'''
|
545 |
+
if upperBound < 0. or maximumDepth == 1:
|
546 |
+
return
|
547 |
+
|
548 |
+
if sk.isHole:
|
549 |
+
yield from self.enumeration(context, environment, request, upperBound,
|
550 |
+
maximumDepth=maximumDepth,
|
551 |
+
lowerBound=lowerBound)
|
552 |
+
elif request.isArrow():
|
553 |
+
assert sk.isAbstraction
|
554 |
+
v = request.arguments[0]
|
555 |
+
for l, newContext, b in self.sketchEnumeration(context, [v] + environment,
|
556 |
+
request.arguments[1],
|
557 |
+
sk.body,
|
558 |
+
upperBound=upperBound,
|
559 |
+
lowerBound=lowerBound,
|
560 |
+
maximumDepth=maximumDepth):
|
561 |
+
yield l, newContext, Abstraction(b)
|
562 |
+
|
563 |
+
else:
|
564 |
+
f, xs = sk.applicationParse()
|
565 |
+
if f.isIndex:
|
566 |
+
ft = environment[f.i].apply(context)
|
567 |
+
elif f.isInvented or f.isPrimitive:
|
568 |
+
context, ft = f.tp.instantiate(context)
|
569 |
+
elif f.isAbstraction:
|
570 |
+
assert False, "sketch is not in beta longform"
|
571 |
+
elif f.isHole:
|
572 |
+
assert False, "hole as function not yet supported"
|
573 |
+
elif f.isApplication:
|
574 |
+
assert False, "should never happen - bug in applicationParse"
|
575 |
+
else: assert False
|
576 |
+
|
577 |
+
try: context = context.unify(ft.returns(), request)
|
578 |
+
except UnificationFailure:
|
579 |
+
print("Exception: sketch is ill-typed")
|
580 |
+
return #so that we can continue evaluating
|
581 |
+
# raise SketchEnumerationFailure() #"sketch is ill-typed"
|
582 |
+
ft = ft.apply(context)
|
583 |
+
argumentRequests = ft.functionArguments()
|
584 |
+
|
585 |
+
assert len(argumentRequests) == len(xs)
|
586 |
+
|
587 |
+
yield from self.sketchApplication(context, environment,
|
588 |
+
f, xs, argumentRequests,
|
589 |
+
upperBound=upperBound,
|
590 |
+
lowerBound=lowerBound,
|
591 |
+
maximumDepth=maximumDepth - 1)
|
592 |
+
|
593 |
+
|
594 |
+
def sketchApplication(self, context, environment,
|
595 |
+
function, arguments, argumentRequests,
|
596 |
+
# Upper bound on the description length of all of
|
597 |
+
# the arguments
|
598 |
+
upperBound,
|
599 |
+
# Lower bound on the description length of all of
|
600 |
+
# the arguments
|
601 |
+
lowerBound=0.,
|
602 |
+
maximumDepth=20):
|
603 |
+
if upperBound < 0. or maximumDepth == 1:
|
604 |
+
return
|
605 |
+
|
606 |
+
if argumentRequests == []:
|
607 |
+
if lowerBound <= 0. and 0. < upperBound:
|
608 |
+
yield 0., context, function
|
609 |
+
else:
|
610 |
+
return
|
611 |
+
else:
|
612 |
+
argRequest = argumentRequests[0].apply(context)
|
613 |
+
laterRequests = argumentRequests[1:]
|
614 |
+
firstSketch = arguments[0]
|
615 |
+
laterSketches = arguments[1:]
|
616 |
+
for argL, newContext, arg in self.sketchEnumeration(context, environment, argRequest,
|
617 |
+
firstSketch,
|
618 |
+
upperBound=upperBound,
|
619 |
+
lowerBound=0.,
|
620 |
+
maximumDepth=maximumDepth):
|
621 |
+
|
622 |
+
newFunction = Application(function, arg)
|
623 |
+
for resultL, resultK, result in self.sketchApplication(newContext, environment, newFunction,
|
624 |
+
laterSketches, laterRequests,
|
625 |
+
upperBound=upperBound + argL,
|
626 |
+
lowerBound=lowerBound + argL,
|
627 |
+
maximumDepth=maximumDepth):
|
628 |
+
|
629 |
+
yield resultL + argL, resultK, result
|
630 |
+
|
631 |
+
def sketchLogLikelihood(self, request, full, sk, context=Context.EMPTY, environment=[]):
|
632 |
+
"""
|
633 |
+
calculates mdl of full program 'full' from sketch 'sk'
|
634 |
+
"""
|
635 |
+
if sk.isHole:
|
636 |
+
_, summary = self.likelihoodSummary(context, environment, request, full)
|
637 |
+
if summary is None:
|
638 |
+
eprint(
|
639 |
+
"FATAL: program [ %s ] does not have a likelihood summary." %
|
640 |
+
full, "r = ", request, "\n", self)
|
641 |
+
assert False
|
642 |
+
return summary.logLikelihood(self), context
|
643 |
+
|
644 |
+
elif request.isArrow():
|
645 |
+
assert sk.isAbstraction and full.isAbstraction
|
646 |
+
#assert sk.f == full.f #is this right? or do i need to recurse?
|
647 |
+
v = request.arguments[0]
|
648 |
+
return self.sketchLogLikelihood(request.arguments[1], full.body, sk.body, context=context, environment=[v] + environment)
|
649 |
+
|
650 |
+
else:
|
651 |
+
sk_f, sk_xs = sk.applicationParse()
|
652 |
+
full_f, full_xs = full.applicationParse()
|
653 |
+
if sk_f.isIndex:
|
654 |
+
assert sk_f == full_f, "sketch and full program don't match on an index"
|
655 |
+
ft = environment[sk_f.i].apply(context)
|
656 |
+
elif sk_f.isInvented or sk_f.isPrimitive:
|
657 |
+
assert sk_f == full_f, "sketch and full program don't match on a primitive"
|
658 |
+
context, ft = sk_f.tp.instantiate(context)
|
659 |
+
elif sk_f.isAbstraction:
|
660 |
+
assert False, "sketch is not in beta longform"
|
661 |
+
elif sk_f.isHole:
|
662 |
+
assert False, "hole as function not yet supported"
|
663 |
+
elif sk_f.isApplication:
|
664 |
+
assert False, "should never happen - bug in applicationParse"
|
665 |
+
else: assert False
|
666 |
+
|
667 |
+
try: context = context.unify(ft.returns(), request)
|
668 |
+
except UnificationFailure: assert False, "sketch is ill-typed"
|
669 |
+
ft = ft.apply(context)
|
670 |
+
argumentRequests = ft.functionArguments()
|
671 |
+
|
672 |
+
assert len(argumentRequests) == len(sk_xs) == len(full_xs) #this might not be true if holes??
|
673 |
+
|
674 |
+
return self.sketchllApplication(context, environment,
|
675 |
+
sk_f, sk_xs, full_f, full_xs, argumentRequests)
|
676 |
+
|
677 |
+
def sketchllApplication(self, context, environment,
|
678 |
+
sk_function, sk_arguments, full_function, full_arguments, argumentRequests):
|
679 |
+
if argumentRequests == []:
|
680 |
+
return torch.tensor([0.]).cuda(), context #does this make sense?
|
681 |
+
else:
|
682 |
+
argRequest = argumentRequests[0].apply(context)
|
683 |
+
laterRequests = argumentRequests[1:]
|
684 |
+
|
685 |
+
sk_firstSketch = sk_arguments[0]
|
686 |
+
full_firstSketch = full_arguments[0]
|
687 |
+
sk_laterSketches = sk_arguments[1:]
|
688 |
+
full_laterSketches = full_arguments[1:]
|
689 |
+
|
690 |
+
argL, newContext = self.sketchLogLikelihood(argRequest, full_firstSketch, sk_firstSketch, context=context, environment=environment)
|
691 |
+
|
692 |
+
#finish this...
|
693 |
+
sk_newFunction = Application(sk_function, sk_firstSketch) # is this redundant? maybe
|
694 |
+
full_newFunction = Application(full_function, full_firstSketch)
|
695 |
+
|
696 |
+
resultL, context = self.sketchllApplication(newContext, environment, sk_newFunction, sk_laterSketches,
|
697 |
+
full_newFunction, full_laterSketches, laterRequests)
|
698 |
+
|
699 |
+
return resultL + argL, context
|
700 |
+
|
701 |
+
|
702 |
+
def enumerateNearby(self, request, expr, distance=3.0):
|
703 |
+
"""Enumerate programs with local mutations in subtrees with small description length"""
|
704 |
+
if distance <= 0:
|
705 |
+
yield expr
|
706 |
+
else:
|
707 |
+
def mutations(tp, loss):
|
708 |
+
for l, _, expr in self.enumeration(
|
709 |
+
Context.EMPTY, [], tp, distance - loss):
|
710 |
+
yield expr, l
|
711 |
+
yield from Mutator(self, mutations).execute(expr, request)
|
712 |
+
|
713 |
+
|
714 |
+
def enumerateHoles(self, request, expr, k=3, return_obj=Hole):
|
715 |
+
"""Enumerate programs with a single hole within mdl distance"""
|
716 |
+
#TODO: make it possible to enumerate sketches with multiple holes
|
717 |
+
def mutations(tp, loss, is_left_application=False):
|
718 |
+
"""
|
719 |
+
to allow applications lhs to become a hole,
|
720 |
+
remove the condition below and ignore all the is_left_application kwds
|
721 |
+
"""
|
722 |
+
if not is_left_application:
|
723 |
+
yield return_obj(), 0
|
724 |
+
top_k = []
|
725 |
+
for expr, l in Mutator(self, mutations).execute(expr, request):
|
726 |
+
if len(top_k) > 0:
|
727 |
+
i, v = min(enumerate(top_k), key=lambda x:x[1][1])
|
728 |
+
if l > v[1]:
|
729 |
+
if len(top_k) >= k:
|
730 |
+
top_k[i] = (expr, l)
|
731 |
+
else:
|
732 |
+
top_k.append((expr, l))
|
733 |
+
elif len(top_k) < k:
|
734 |
+
top_k.append((expr, l))
|
735 |
+
else:
|
736 |
+
top_k.append((expr, l))
|
737 |
+
return sorted(top_k, key=lambda x:-x[1])
|
738 |
+
|
739 |
+
def untorch(self):
|
740 |
+
return Grammar(self.logVariable.data.tolist()[0],
|
741 |
+
[ (l.data.tolist()[0], t, p)
|
742 |
+
for l, t, p in self.productions],
|
743 |
+
continuationType=self.continuationType)
|
744 |
+
|
745 |
+
class LikelihoodSummary(object):
|
746 |
+
'''Summarizes the terms that will be used in a likelihood calculation'''
|
747 |
+
|
748 |
+
def __init__(self):
|
749 |
+
self.uses = {}
|
750 |
+
self.normalizers = {}
|
751 |
+
self.constant = 0.
|
752 |
+
|
753 |
+
def __str__(self):
|
754 |
+
return """LikelihoodSummary(constant = %f,
|
755 |
+
uses = {%s},
|
756 |
+
normalizers = {%s})""" % (self.constant,
|
757 |
+
", ".join(
|
758 |
+
"%s: %d" % (k,
|
759 |
+
v) for k,
|
760 |
+
v in self.uses.items()),
|
761 |
+
", ".join(
|
762 |
+
"%s: %d" % (k,
|
763 |
+
v) for k,
|
764 |
+
v in self.normalizers.items()))
|
765 |
+
|
766 |
+
def record(self, actual, possibles, constant=0.):
|
767 |
+
# Variables are all normalized to be $0
|
768 |
+
if isinstance(actual, Index):
|
769 |
+
actual = Index(0)
|
770 |
+
|
771 |
+
# Make it something that we can hash
|
772 |
+
possibles = frozenset(sorted(possibles, key=hash))
|
773 |
+
|
774 |
+
self.constant += constant
|
775 |
+
self.uses[actual] = self.uses.get(actual, 0) + 1
|
776 |
+
self.normalizers[possibles] = self.normalizers.get(possibles, 0) + 1
|
777 |
+
|
778 |
+
def join(self, other):
|
779 |
+
self.constant += other.constant
|
780 |
+
for k, v in other.uses.items():
|
781 |
+
self.uses[k] = self.uses.get(k, 0) + v
|
782 |
+
for k, v in other.normalizers.items():
|
783 |
+
self.normalizers[k] = self.normalizers.get(k, 0) + v
|
784 |
+
|
785 |
+
def logLikelihood(self, grammar):
|
786 |
+
return self.constant + \
|
787 |
+
sum(count * grammar.expression2likelihood[p] for p, count in self.uses.items()) - \
|
788 |
+
sum(count * lse([grammar.expression2likelihood[p] for p in ps])
|
789 |
+
for ps, count in self.normalizers.items())
|
790 |
+
def logLikelihood_overlyGeneral(self, grammar):
|
791 |
+
"""Calculates log likelihood of this summary, given that the summary might refer to productions that don't occur in the grammar"""
|
792 |
+
return self.constant + \
|
793 |
+
sum(count * grammar.expression2likelihood[p] for p, count in self.uses.items()) - \
|
794 |
+
sum(count * lse([grammar.expression2likelihood.get(p,NEGATIVEINFINITY) for p in ps])
|
795 |
+
for ps, count in self.normalizers.items())
|
796 |
+
def numerator(self, grammar):
|
797 |
+
return self.constant + \
|
798 |
+
sum(count * grammar.expression2likelihood[p] for p, count in self.uses.items())
|
799 |
+
def denominator(self, grammar):
|
800 |
+
return \
|
801 |
+
sum(count * lse([grammar.expression2likelihood[p] for p in ps])
|
802 |
+
for ps, count in self.normalizers.items())
|
803 |
+
def toUses(self):
|
804 |
+
from collections import Counter
|
805 |
+
|
806 |
+
possibleVariables = sum( count if Index(0) in ps else 0
|
807 |
+
for ps, count in self.normalizers.items() )
|
808 |
+
actualVariables = self.uses.get(Index(0), 0.)
|
809 |
+
actualUses = {k: v
|
810 |
+
for k, v in self.uses.items()
|
811 |
+
if not k.isIndex }
|
812 |
+
possibleUses = dict(Counter(p
|
813 |
+
for ps, count in self.normalizers.items()
|
814 |
+
for p_ in ps
|
815 |
+
if not p_.isIndex
|
816 |
+
for p in [p_]*count ))
|
817 |
+
return Uses(possibleVariables, actualVariables,
|
818 |
+
possibleUses, actualUses)
|
819 |
+
|
820 |
+
|
821 |
+
class Uses(object):
|
822 |
+
'''Tracks uses of different grammar productions'''
|
823 |
+
|
824 |
+
def __init__(self, possibleVariables=0., actualVariables=0.,
|
825 |
+
possibleUses={}, actualUses={}):
|
826 |
+
self.actualVariables = actualVariables
|
827 |
+
self.possibleVariables = possibleVariables
|
828 |
+
self.possibleUses = possibleUses
|
829 |
+
self.actualUses = actualUses
|
830 |
+
|
831 |
+
def __str__(self):
|
832 |
+
return "Uses(actualVariables = %f, possibleVariables = %f, actualUses = %s, possibleUses = %s)" %\
|
833 |
+
(self.actualVariables, self.possibleVariables, self.actualUses, self.possibleUses)
|
834 |
+
|
835 |
+
def __repr__(self): return str(self)
|
836 |
+
|
837 |
+
def __mul__(self, a):
|
838 |
+
return Uses(a * self.possibleVariables,
|
839 |
+
a * self.actualVariables,
|
840 |
+
{p: a * u for p, u in self.possibleUses.items()},
|
841 |
+
{p: a * u for p, u in self.actualUses.items()})
|
842 |
+
|
843 |
+
def __imul__(self, a):
|
844 |
+
self.possibleVariables *= a
|
845 |
+
self.actualVariables *= a
|
846 |
+
for p in self.possibleUses:
|
847 |
+
self.possibleUses[p] *= a
|
848 |
+
for p in self.actualUses:
|
849 |
+
self.actualUses[p] *= a
|
850 |
+
return self
|
851 |
+
|
852 |
+
def __rmul__(self, a):
|
853 |
+
return self * a
|
854 |
+
|
855 |
+
def __radd__(self, o):
|
856 |
+
if o == 0:
|
857 |
+
return self
|
858 |
+
return self + o
|
859 |
+
|
860 |
+
def __add__(self, o):
|
861 |
+
if o == 0:
|
862 |
+
return self
|
863 |
+
|
864 |
+
def merge(x, y):
|
865 |
+
z = x.copy()
|
866 |
+
for k, v in y.items():
|
867 |
+
z[k] = v + x.get(k, 0.)
|
868 |
+
return z
|
869 |
+
return Uses(self.possibleVariables + o.possibleVariables,
|
870 |
+
self.actualVariables + o.actualVariables,
|
871 |
+
merge(self.possibleUses, o.possibleUses),
|
872 |
+
merge(self.actualUses, o.actualUses))
|
873 |
+
|
874 |
+
def __iadd__(self, o):
|
875 |
+
self.possibleVariables += o.possibleVariables
|
876 |
+
self.actualVariables += o.actualVariables
|
877 |
+
for k, v in o.possibleUses.items():
|
878 |
+
self.possibleUses[k] = self.possibleUses.get(k, 0.) + v
|
879 |
+
for k, v in o.actualUses.items():
|
880 |
+
self.actualUses[k] = self.actualUses.get(k, 0.) + v
|
881 |
+
return self
|
882 |
+
|
883 |
+
@staticmethod
|
884 |
+
def join(z, *weightedUses):
|
885 |
+
"""Consumes weightedUses"""
|
886 |
+
if not weightedUses:
|
887 |
+
Uses.empty
|
888 |
+
if len(weightedUses) == 1:
|
889 |
+
return weightedUses[0][1]
|
890 |
+
for w, u in weightedUses:
|
891 |
+
u *= exp(w - z)
|
892 |
+
total = Uses()
|
893 |
+
total.possibleVariables = sum(
|
894 |
+
u.possibleVariables for _, u in weightedUses)
|
895 |
+
total.actualVariables = sum(u.actualVariables for _, u in weightedUses)
|
896 |
+
total.possibleUses = defaultdict(float)
|
897 |
+
total.actualUses = defaultdict(float)
|
898 |
+
for _, u in weightedUses:
|
899 |
+
for k, v in u.possibleUses.items():
|
900 |
+
total.possibleUses[k] += v
|
901 |
+
for k, v in u.actualUses.items():
|
902 |
+
total.actualUses[k] += v
|
903 |
+
return total
|
904 |
+
|
905 |
+
|
906 |
+
Uses.empty = Uses()
|
907 |
+
|
908 |
+
class ContextualGrammar:
|
909 |
+
def __init__(self, noParent, variableParent, library):
|
910 |
+
self.noParent, self.variableParent, self.library = noParent, variableParent, library
|
911 |
+
|
912 |
+
self.productions = [(None,t,p) for _,t,p in self.noParent.productions ]
|
913 |
+
self.primitives = [p for _,_2,p in self.productions ]
|
914 |
+
|
915 |
+
self.continuationType = noParent.continuationType
|
916 |
+
assert variableParent.continuationType == self.continuationType
|
917 |
+
|
918 |
+
assert set(noParent.primitives) == set(variableParent.primitives)
|
919 |
+
assert set(variableParent.primitives) == set(library.keys())
|
920 |
+
for e,gs in library.items():
|
921 |
+
assert len(gs) == len(e.infer().functionArguments())
|
922 |
+
for g in gs:
|
923 |
+
assert set(g.primitives) == set(library.keys())
|
924 |
+
assert g.continuationType == self.continuationType
|
925 |
+
|
926 |
+
def untorch(self):
|
927 |
+
return ContextualGrammar(self.noParent.untorch(), self.variableParent.untorch(),
|
928 |
+
{e: [g.untorch() for g in gs ]
|
929 |
+
for e,gs in self.library.items() })
|
930 |
+
|
931 |
+
def randomWeights(self, r):
|
932 |
+
"""returns a new grammar with random weights drawn from r. calls `r` w/ old weight"""
|
933 |
+
return ContextualGrammar(self.noParent.randomWeights(r),
|
934 |
+
self.variableParent.randomWeights(r),
|
935 |
+
{e: [g.randomWeights(r) for g in gs]
|
936 |
+
for e,gs in self.library.items() })
|
937 |
+
def __str__(self):
|
938 |
+
lines = ["No parent:",str(self.noParent),"",
|
939 |
+
"Variable parent:",str(self.variableParent),"",
|
940 |
+
""]
|
941 |
+
for e,gs in self.library.items():
|
942 |
+
for j,g in enumerate(gs):
|
943 |
+
lines.extend(["Parent %s, argument index %s"%(e,j),
|
944 |
+
str(g),
|
945 |
+
""])
|
946 |
+
return "\n".join(lines)
|
947 |
+
|
948 |
+
def json(self):
|
949 |
+
return {"noParent": self.noParent.json(),
|
950 |
+
"variableParent": self.variableParent.json(),
|
951 |
+
"productions": [{"program": str(e),
|
952 |
+
"arguments": [gp.json() for gp in gs ]}
|
953 |
+
for e,gs in self.library.items() ]}
|
954 |
+
|
955 |
+
@staticmethod
|
956 |
+
def fromGrammar(g):
|
957 |
+
return ContextualGrammar(g, g,
|
958 |
+
{e: [g]*len(e.infer().functionArguments())
|
959 |
+
for e in g.primitives })
|
960 |
+
|
961 |
+
|
962 |
+
class LS: # likelihood summary
|
963 |
+
def __init__(self, owner):
|
964 |
+
self.noParent = LikelihoodSummary()
|
965 |
+
self.variableParent = LikelihoodSummary()
|
966 |
+
self.library = {e: [LikelihoodSummary() for _ in gs] for e,gs in owner.library.items() }
|
967 |
+
|
968 |
+
def record(self, parent, parentIndex, actual, possibles, constant):
|
969 |
+
if parent is None: ls = self.noParent
|
970 |
+
elif parent.isIndex: ls = self.variableParent
|
971 |
+
else: ls = self.library[parent][parentIndex]
|
972 |
+
ls.record(actual, possibles, constant=constant)
|
973 |
+
|
974 |
+
def join(self, other):
|
975 |
+
self.noParent.join(other.noParent)
|
976 |
+
self.variableParent.join(other.variableParent)
|
977 |
+
for e,gs in self.library.items():
|
978 |
+
for g1,g2 in zip(gs, other.library[e]):
|
979 |
+
g1.join(g2)
|
980 |
+
|
981 |
+
def logLikelihood(self, owner):
|
982 |
+
return self.noParent.logLikelihood(owner.noParent) + \
|
983 |
+
self.variableParent.logLikelihood(owner.variableParent) + \
|
984 |
+
sum(r.logLikelihood(g)
|
985 |
+
for e, rs in self.library.items()
|
986 |
+
for r,g in zip(rs, owner.library[e]) )
|
987 |
+
def numerator(self, owner):
|
988 |
+
return self.noParent.numerator(owner.noParent) + \
|
989 |
+
self.variableParent.numerator(owner.variableParent) + \
|
990 |
+
sum(r.numerator(g)
|
991 |
+
for e, rs in self.library.items()
|
992 |
+
for r,g in zip(rs, owner.library[e]) )
|
993 |
+
def denominator(self, owner):
|
994 |
+
return self.noParent.denominator(owner.noParent) + \
|
995 |
+
self.variableParent.denominator(owner.variableParent) + \
|
996 |
+
sum(r.denominator(g)
|
997 |
+
for e, rs in self.library.items()
|
998 |
+
for r,g in zip(rs, owner.library[e]) )
|
999 |
+
|
1000 |
+
def likelihoodSummary(self, parent, parentIndex, context, environment, request, expression):
|
1001 |
+
if request.isArrow():
|
1002 |
+
assert expression.isAbstraction
|
1003 |
+
return self.likelihoodSummary(parent, parentIndex,
|
1004 |
+
context,
|
1005 |
+
[request.arguments[0]] + environment,
|
1006 |
+
request.arguments[1],
|
1007 |
+
expression.body)
|
1008 |
+
if parent is None: g = self.noParent
|
1009 |
+
elif parent.isIndex: g = self.variableParent
|
1010 |
+
else: g = self.library[parent][parentIndex]
|
1011 |
+
candidates = g.buildCandidates(request, context, environment,
|
1012 |
+
normalize=False, returnTable=True)
|
1013 |
+
|
1014 |
+
# A list of everything that would have been possible to use here
|
1015 |
+
possibles = [p for p in candidates.keys() if not p.isIndex]
|
1016 |
+
numberOfVariables = sum(p.isIndex for p in candidates.keys())
|
1017 |
+
if numberOfVariables > 0:
|
1018 |
+
possibles += [Index(0)]
|
1019 |
+
|
1020 |
+
f, xs = expression.applicationParse()
|
1021 |
+
|
1022 |
+
assert f in candidates
|
1023 |
+
|
1024 |
+
thisSummary = self.LS(self)
|
1025 |
+
thisSummary.record(parent, parentIndex,
|
1026 |
+
f, possibles,
|
1027 |
+
constant= -math.log(numberOfVariables) if f.isIndex else 0)
|
1028 |
+
|
1029 |
+
_, tp, context = candidates[f]
|
1030 |
+
argumentTypes = tp.functionArguments()
|
1031 |
+
assert len(xs) == len(argumentTypes)
|
1032 |
+
|
1033 |
+
for i, (argumentType, argument) in enumerate(zip(argumentTypes, xs)):
|
1034 |
+
argumentType = argumentType.apply(context)
|
1035 |
+
context, newSummary = self.likelihoodSummary(f, i,
|
1036 |
+
context, environment, argumentType, argument)
|
1037 |
+
thisSummary.join(newSummary)
|
1038 |
+
|
1039 |
+
return context, thisSummary
|
1040 |
+
|
1041 |
+
def closedLikelihoodSummary(self, request, expression):
|
1042 |
+
return self.likelihoodSummary(None,None,
|
1043 |
+
Context.EMPTY,[],
|
1044 |
+
request, expression)[1]
|
1045 |
+
|
1046 |
+
def logLikelihood(self, request, expression):
|
1047 |
+
return self.closedLikelihoodSummary(request, expression).logLikelihood(self)
|
1048 |
+
|
1049 |
+
def sample(self, request, maximumDepth=8, maxAttempts=None):
|
1050 |
+
attempts = 0
|
1051 |
+
while True:
|
1052 |
+
try:
|
1053 |
+
_, e = self._sample(None, None, Context.EMPTY, [], request, maximumDepth)
|
1054 |
+
return e
|
1055 |
+
except NoCandidates:
|
1056 |
+
if maxAttempts is not None:
|
1057 |
+
attempts += 1
|
1058 |
+
if attempts > maxAttempts: return None
|
1059 |
+
continue
|
1060 |
+
|
1061 |
+
def _sample(self, parent, parentIndex, context, environment, request, maximumDepth):
|
1062 |
+
if request.isArrow():
|
1063 |
+
context, body = self._sample(parent, parentIndex, context,
|
1064 |
+
[request.arguments[0]] + environment,
|
1065 |
+
request.arguments[1],
|
1066 |
+
maximumDepth)
|
1067 |
+
return context, Abstraction(body)
|
1068 |
+
if parent is None: g = self.noParent
|
1069 |
+
elif parent.isIndex: g = self.variableParent
|
1070 |
+
else: g = self.library[parent][parentIndex]
|
1071 |
+
candidates = g.buildCandidates(request, context, environment,
|
1072 |
+
normalize=True, returnProbabilities=True,
|
1073 |
+
mustBeLeaf=(maximumDepth <= 1))
|
1074 |
+
newType, chosenPrimitive, context = sampleDistribution(candidates)
|
1075 |
+
|
1076 |
+
xs = newType.functionArguments()
|
1077 |
+
returnValue = chosenPrimitive
|
1078 |
+
|
1079 |
+
for j,x in enumerate(xs):
|
1080 |
+
x = x.apply(context)
|
1081 |
+
context, x = self._sample(chosenPrimitive, j, context, environment, x, maximumDepth - 1)
|
1082 |
+
returnValue = Application(returnValue, x)
|
1083 |
+
|
1084 |
+
return context, returnValue
|
1085 |
+
|
1086 |
+
def expectedUsesMonteCarlo(self, request, debug=None):
|
1087 |
+
import numpy as np
|
1088 |
+
n = 0
|
1089 |
+
u = [0.]*len(self.primitives)
|
1090 |
+
primitives = list(sorted(self.primitives, key=str))
|
1091 |
+
noInventions = all( not p.isInvented for p in primitives )
|
1092 |
+
primitive2index = {primitive: i
|
1093 |
+
for i, primitive in enumerate(primitives)
|
1094 |
+
if primitive.isInvented or noInventions }
|
1095 |
+
eprint(primitive2index)
|
1096 |
+
ns = 10000
|
1097 |
+
with timing(f"calculated expected uses using Monte Carlo simulation w/ {ns} samples"):
|
1098 |
+
for _ in range(ns):
|
1099 |
+
p = self.sample(request, maxAttempts=0)
|
1100 |
+
if p is None: continue
|
1101 |
+
n += 1
|
1102 |
+
if debug and n < 10:
|
1103 |
+
eprint(debug, p)
|
1104 |
+
for _, child in p.walk():
|
1105 |
+
if child not in primitive2index: continue
|
1106 |
+
u[primitive2index[child]] += 1.0
|
1107 |
+
u = np.array(u)/n
|
1108 |
+
if debug:
|
1109 |
+
eprint(f"Got {n} samples. Feature vector:\n{u}")
|
1110 |
+
eprint(f"Likely used primitives: {[p for p,i in primitive2index.items() if u[i] > 0.5]}")
|
1111 |
+
eprint(f"Likely used primitive indices: {[i for p,i in primitive2index.items() if u[i] > 0.5]}")
|
1112 |
+
return u
|
1113 |
+
|
1114 |
+
def featureVector(self, _=None, requests=None, onlyInventions=True, normalize=True):
|
1115 |
+
"""
|
1116 |
+
Returns the probabilities licensed by the type system.
|
1117 |
+
This is like the grammar productions, but with irrelevant junk removed.
|
1118 |
+
Its intended use case is for clustering; it should be strictly better than the raw transition matrix.
|
1119 |
+
"""
|
1120 |
+
if requests is None:
|
1121 |
+
if self.continuationType: requests = {self.continuationType}
|
1122 |
+
elif any( 'REAL' == str(p) for p in self.primitives ): requests = set()
|
1123 |
+
elif any( 'STRING' == str(p) for p in self.primitives ): requests = {tlist(tcharacter)}
|
1124 |
+
else: requests = set()
|
1125 |
+
requests = {r.returns() for r in requests}
|
1126 |
+
features = []
|
1127 |
+
logWeights = []
|
1128 |
+
for l,t,p in sorted(self.noParent.productions,
|
1129 |
+
key=lambda z: str(z[2])):
|
1130 |
+
if onlyInventions and not p.isInvented: continue
|
1131 |
+
if any( canUnify(r, t.returns()) for r in requests ) or len(requests) == 0:
|
1132 |
+
logWeights.append(l)
|
1133 |
+
features.append(logWeights)
|
1134 |
+
for parent in sorted(self.primitives, key=str):
|
1135 |
+
if onlyInventions and not parent.isInvented: continue
|
1136 |
+
if parent not in self.library: continue
|
1137 |
+
argumentTypes = parent.infer().functionArguments()
|
1138 |
+
for j,g in enumerate(self.library[parent]):
|
1139 |
+
argumentType = argumentTypes[j]
|
1140 |
+
logWeights = []
|
1141 |
+
for l,t,p in sorted(g.productions,
|
1142 |
+
key=lambda z: str(z[2])):
|
1143 |
+
if onlyInventions and not p.isInvented: continue
|
1144 |
+
if canUnify(argumentType.returns(), t.returns()):
|
1145 |
+
logWeights.append(l)
|
1146 |
+
features.append(logWeights)
|
1147 |
+
|
1148 |
+
if normalize:
|
1149 |
+
features = [ [math.exp(w - z) for w in lw ]
|
1150 |
+
for lw in features
|
1151 |
+
if lw
|
1152 |
+
for z in [lse(lw)] ]
|
1153 |
+
import numpy as np
|
1154 |
+
return np.array([f
|
1155 |
+
for lw in features
|
1156 |
+
for f in lw])
|
1157 |
+
|
1158 |
+
def enumeration(self,context,environment,request,upperBound,
|
1159 |
+
parent=None, parentIndex=None,
|
1160 |
+
maximumDepth=20,
|
1161 |
+
lowerBound=0.):
|
1162 |
+
'''Enumerates all programs whose MDL satisfies: lowerBound <= MDL < upperBound'''
|
1163 |
+
if upperBound < 0 or maximumDepth == 1:
|
1164 |
+
return
|
1165 |
+
|
1166 |
+
if request.isArrow():
|
1167 |
+
v = request.arguments[0]
|
1168 |
+
for l, newContext, b in self.enumeration(context, [v] + environment,
|
1169 |
+
request.arguments[1],
|
1170 |
+
parent=parent, parentIndex=parentIndex,
|
1171 |
+
upperBound=upperBound,
|
1172 |
+
lowerBound=lowerBound,
|
1173 |
+
maximumDepth=maximumDepth):
|
1174 |
+
yield l, newContext, Abstraction(b)
|
1175 |
+
else:
|
1176 |
+
if parent is None: g = self.noParent
|
1177 |
+
elif parent.isIndex: g = self.variableParent
|
1178 |
+
else: g = self.library[parent][parentIndex]
|
1179 |
+
|
1180 |
+
candidates = g.buildCandidates(request, context, environment,
|
1181 |
+
normalize=True)
|
1182 |
+
|
1183 |
+
for l, t, p, newContext in candidates:
|
1184 |
+
mdl = -l
|
1185 |
+
if not (mdl < upperBound):
|
1186 |
+
continue
|
1187 |
+
|
1188 |
+
xs = t.functionArguments()
|
1189 |
+
for aL, aK, application in\
|
1190 |
+
self.enumerateApplication(newContext, environment, p, xs,
|
1191 |
+
parent=p,
|
1192 |
+
upperBound=upperBound + l,
|
1193 |
+
lowerBound=lowerBound + l,
|
1194 |
+
maximumDepth=maximumDepth - 1):
|
1195 |
+
yield aL + l, aK, application
|
1196 |
+
|
1197 |
+
def enumerateApplication(self, context, environment,
|
1198 |
+
function, argumentRequests,
|
1199 |
+
# Upper bound on the description length of all of
|
1200 |
+
# the arguments
|
1201 |
+
upperBound,
|
1202 |
+
# Lower bound on the description length of all of
|
1203 |
+
# the arguments
|
1204 |
+
lowerBound=0.,
|
1205 |
+
maximumDepth=20,
|
1206 |
+
parent=None,
|
1207 |
+
originalFunction=None,
|
1208 |
+
argumentIndex=0):
|
1209 |
+
assert parent is not None
|
1210 |
+
if upperBound < 0. or maximumDepth == 1:
|
1211 |
+
return
|
1212 |
+
if originalFunction is None:
|
1213 |
+
originalFunction = function
|
1214 |
+
|
1215 |
+
if argumentRequests == []:
|
1216 |
+
if lowerBound <= 0. and 0. < upperBound:
|
1217 |
+
yield 0., context, function
|
1218 |
+
else:
|
1219 |
+
return
|
1220 |
+
else:
|
1221 |
+
argRequest = argumentRequests[0].apply(context)
|
1222 |
+
laterRequests = argumentRequests[1:]
|
1223 |
+
for argL, newContext, arg in self.enumeration(context, environment, argRequest,
|
1224 |
+
parent=parent, parentIndex=argumentIndex,
|
1225 |
+
upperBound=upperBound,
|
1226 |
+
lowerBound=0.,
|
1227 |
+
maximumDepth=maximumDepth):
|
1228 |
+
if violatesSymmetry(originalFunction, arg, argumentIndex):
|
1229 |
+
continue
|
1230 |
+
|
1231 |
+
newFunction = Application(function, arg)
|
1232 |
+
for resultL, resultK, result in self.enumerateApplication(newContext, environment, newFunction,
|
1233 |
+
laterRequests,
|
1234 |
+
parent=parent,
|
1235 |
+
upperBound=upperBound + argL,
|
1236 |
+
lowerBound=lowerBound + argL,
|
1237 |
+
maximumDepth=maximumDepth,
|
1238 |
+
originalFunction=originalFunction,
|
1239 |
+
argumentIndex=argumentIndex + 1):
|
1240 |
+
yield resultL + argL, resultK, result
|
1241 |
+
|
1242 |
+
|
1243 |
+
|
1244 |
+
|
1245 |
+
def violatesSymmetry(f, x, argumentIndex):
|
1246 |
+
if not f.isPrimitive:
|
1247 |
+
return False
|
1248 |
+
while x.isApplication:
|
1249 |
+
x = x.f
|
1250 |
+
if not x.isPrimitive:
|
1251 |
+
return False
|
1252 |
+
f = f.name
|
1253 |
+
x = x.name
|
1254 |
+
if f == "car":
|
1255 |
+
return x == "cons" or x == "empty"
|
1256 |
+
if f == "cdr":
|
1257 |
+
return x == "cons" or x == "empty"
|
1258 |
+
if f == "+":
|
1259 |
+
return x == "0" or (argumentIndex == 1 and x == "+")
|
1260 |
+
if f == "-":
|
1261 |
+
return argumentIndex == 1 and x == "0"
|
1262 |
+
if f == "empty?":
|
1263 |
+
return x == "cons" or x == "empty"
|
1264 |
+
if f == "zero?":
|
1265 |
+
return x == "0" or x == "1"
|
1266 |
+
if f == "index" or f == "map" or f == "zip":
|
1267 |
+
return x == "empty"
|
1268 |
+
if f == "range":
|
1269 |
+
return x == "0"
|
1270 |
+
if f == "fold":
|
1271 |
+
return argumentIndex == 1 and x == "empty"
|
1272 |
+
return False
|
1273 |
+
|
1274 |
+
def batchLikelihood(jobs):
|
1275 |
+
"""Takes as input a set of (program, request, grammar) and returns a dictionary mapping each of these to its likelihood under the grammar"""
|
1276 |
+
superGrammar = Grammar.uniform(list({p for _1,_2,g in jobs for p in g.primitives}),
|
1277 |
+
continuationType=list(jobs)[0][-1].continuationType)
|
1278 |
+
programsAndRequests = {(program, request)
|
1279 |
+
for program, request, grammar in jobs}
|
1280 |
+
with timing(f"Calculated {len(programsAndRequests)} likelihood summaries"):
|
1281 |
+
summary = {(program, request): superGrammar.closedLikelihoodSummary(request, program)
|
1282 |
+
for program, request in programsAndRequests}
|
1283 |
+
with timing(f"Calculated log likelihoods from summaries"):
|
1284 |
+
response = {}
|
1285 |
+
for program, request, grammar in jobs:
|
1286 |
+
fast = summary[(program, request)].logLikelihood_overlyGeneral(grammar)
|
1287 |
+
if False: # debugging
|
1288 |
+
slow = grammar.logLikelihood(request, program)
|
1289 |
+
print(program)
|
1290 |
+
eprint(grammar.closedLikelihoodSummary(request, program))
|
1291 |
+
eprint(superGrammar.closedLikelihoodSummary(request, program))
|
1292 |
+
print()
|
1293 |
+
assert abs(fast - slow) < 0.0001
|
1294 |
+
response[(program, request, grammar)] = fast
|
1295 |
+
return response
|
1296 |
+
|
1297 |
+
if __name__ == "__main__":
|
1298 |
+
from dreamcoder.domains.arithmetic.arithmeticPrimitives import *
|
1299 |
+
g = ContextualGrammar.fromGrammar(Grammar.uniform([k0,k1,addition, subtraction]))
|
1300 |
+
g = g.randomWeights(lambda *a: random.random())
|
1301 |
+
#p = Program.parse("(lambda (+ 1 $0))")
|
1302 |
+
request = arrow(tint,tint)
|
1303 |
+
for ll,_,p in g.enumeration(Context.EMPTY,[],request,
|
1304 |
+
12.):
|
1305 |
+
ll_ = g.logLikelihood(request,p)
|
1306 |
+
print(ll,p,ll_)
|
1307 |
+
d = abs(ll - ll_)
|
1308 |
+
assert d < 0.0001
|
dreamcoder/likelihoodModel.py
ADDED
@@ -0,0 +1,407 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dreamcoder.task import Task, EvaluationTimeout
|
2 |
+
import gc
|
3 |
+
from dreamcoder.utilities import *
|
4 |
+
from collections import Counter
|
5 |
+
import math
|
6 |
+
|
7 |
+
from dreamcoder.domains.regex.groundtruthRegexes import gt_dict
|
8 |
+
|
9 |
+
gt_dict = {"Data column no. "+str(num): r_str for num, r_str in gt_dict.items()}
|
10 |
+
|
11 |
+
class AllOrNothingLikelihoodModel:
|
12 |
+
def __init__(self, timeout=None):
|
13 |
+
self.timeout = timeout
|
14 |
+
|
15 |
+
def score(self, program, task):
|
16 |
+
logLikelihood = task.logLikelihood(program, self.timeout)
|
17 |
+
return valid(logLikelihood), logLikelihood
|
18 |
+
|
19 |
+
|
20 |
+
class EuclideanLikelihoodModel:
|
21 |
+
"""Likelihood is based on Euclidean distance between features"""
|
22 |
+
|
23 |
+
def __init__(self, featureExtractor, successCutoff=0.9):
|
24 |
+
self.extract = featureExtractor
|
25 |
+
self.successCutoff = successCutoff
|
26 |
+
|
27 |
+
def score(self, program, task):
|
28 |
+
taskFeat = self.extract.featuresOfTask(task)
|
29 |
+
progFeat = self.extract.featuresOfProgram(program, task.request)
|
30 |
+
assert len(taskFeat) == len(progFeat)
|
31 |
+
distance = sum((x1 - x2)**2 for x1, x2 in zip(taskFeat, progFeat))
|
32 |
+
logLikelihood = float(-distance) # FIXME: this is really naive
|
33 |
+
return exp(logLikelihood) > self.successCutoff, logLikelihood
|
34 |
+
|
35 |
+
def longest_common_substr(arr):
|
36 |
+
#array of examples
|
37 |
+
|
38 |
+
# Python 3 program to find the stem
|
39 |
+
# of given list of words
|
40 |
+
# function to find the stem (longest
|
41 |
+
# common substring) from the string array
|
42 |
+
# Determine size of the array
|
43 |
+
n = len(arr)
|
44 |
+
|
45 |
+
# Take first word from array
|
46 |
+
# as reference
|
47 |
+
s = arr[0]
|
48 |
+
l = len(s)
|
49 |
+
res = ""
|
50 |
+
for i in range(l) :
|
51 |
+
for j in range( i + 1, l + 1) :
|
52 |
+
# generating all possible substrings
|
53 |
+
# of our reference string arr[0] i.e s
|
54 |
+
stem = s[i:j]
|
55 |
+
k = 1
|
56 |
+
for k in range(1, n):
|
57 |
+
|
58 |
+
# Check if the generated stem is
|
59 |
+
# common to to all words
|
60 |
+
if stem not in arr[k]:
|
61 |
+
break
|
62 |
+
|
63 |
+
# If current substring is present in
|
64 |
+
# all strings and its length is greater
|
65 |
+
# than current result
|
66 |
+
if (k + 1 == n and len(res) < len(stem)): res = stem
|
67 |
+
return res
|
68 |
+
|
69 |
+
def add_string_constants(tasks):
|
70 |
+
for task in tasks:
|
71 |
+
task.str_const = longest_common_substr([example[1] for example in task.examples])
|
72 |
+
return tasks
|
73 |
+
|
74 |
+
def get_gt_ll(name, examples):
|
75 |
+
#gets groundtruth from dict
|
76 |
+
import pregex as pre
|
77 |
+
r_str = gt_dict[name]
|
78 |
+
preg = pre.create(r_str)
|
79 |
+
|
80 |
+
if type(examples[0]) == list:
|
81 |
+
examples = [ "".join(example) for example in examples]
|
82 |
+
|
83 |
+
s = sum( preg.match(example) for example in examples)
|
84 |
+
if s == float("-inf"):
|
85 |
+
print("bad for ", name)
|
86 |
+
print('preg:', preg)
|
87 |
+
print('preg sample:', [preg.sample() for i in range(3)])
|
88 |
+
print("exs", examples)
|
89 |
+
#assert False
|
90 |
+
return s
|
91 |
+
|
92 |
+
|
93 |
+
def add_cutoff_values(tasks, ll_cutoff):
|
94 |
+
from dreamcoder.domains.regex.makeRegexTasks import makeNewTasks
|
95 |
+
if ll_cutoff is None or ll_cutoff == "None":
|
96 |
+
for task in tasks:
|
97 |
+
task.ll_cutoff = None
|
98 |
+
return tasks
|
99 |
+
if ll_cutoff == "gt":
|
100 |
+
from dreamcoder.domains.regex.makeRegexTasks import regexHeldOutExamples
|
101 |
+
for task in tasks:
|
102 |
+
task.ll_cutoff = None
|
103 |
+
task.gt = get_gt_ll(task.name, [example[1] for example in task.examples])
|
104 |
+
task.gt_test = get_gt_ll(task.name,
|
105 |
+
[example[1] for example in regexHeldOutExamples(task) ])
|
106 |
+
return tasks
|
107 |
+
elif ll_cutoff == "plus":
|
108 |
+
for task in tasks:
|
109 |
+
task.ll_cutoff = regex_plus_bound([example[1] for example in task.examples])
|
110 |
+
return tasks
|
111 |
+
elif ll_cutoff == "bigram":
|
112 |
+
eprint("WARNING: using entire corpus to make bigram model")
|
113 |
+
#this means i do it twice, which is eh whatever
|
114 |
+
model = make_corpus_bigram(show_tasks(makeNewTasks()))
|
115 |
+
for task in tasks:
|
116 |
+
task.ll_cutoff = bigram_corpus_score([example[1] for example in task.examples], model)
|
117 |
+
return tasks
|
118 |
+
elif ll_cutoff =="unigram":
|
119 |
+
eprint("WARNING: using entire corpus to make unigram model")
|
120 |
+
#this means i do it twice, which is eh whatever
|
121 |
+
model = make_corpus_unigram(show_tasks(makeNewTasks()))
|
122 |
+
for task in tasks:
|
123 |
+
task.ll_cutoff = unigram_corpus_score([example[1] for example in task.examples], model)
|
124 |
+
return tasks
|
125 |
+
elif ll_cutoff =="mix":
|
126 |
+
eprint("WARNING: using entire corpus to make bigram model")
|
127 |
+
eprint("WARNING: using entire corpus to make unigram model")
|
128 |
+
#this means i do it twice, which is eh whatever
|
129 |
+
unigram = make_corpus_unigram(show_tasks(makeNewTasks()))
|
130 |
+
bigram = make_corpus_bigram(show_tasks(makeNewTasks()))
|
131 |
+
for task in tasks:
|
132 |
+
uniscore = unigram_corpus_score([example[1] for example in task.examples], unigram)
|
133 |
+
biscore = bigram_corpus_score([example[1] for example in task.examples], bigram)
|
134 |
+
task.ll_cutoff = math.log(0.75*math.exp(biscore) + 0.25*math.exp(uniscore))
|
135 |
+
return tasks
|
136 |
+
else:
|
137 |
+
eprint("not implemented")
|
138 |
+
eprint("cutoff val:")
|
139 |
+
eprint(ll_cutoff)
|
140 |
+
assert False
|
141 |
+
|
142 |
+
def show_tasks(dataset):
|
143 |
+
task_list = []
|
144 |
+
for task in dataset:
|
145 |
+
task_list.append([example[1] for example in task.examples])
|
146 |
+
return task_list
|
147 |
+
|
148 |
+
def regex_plus_bound(X):
|
149 |
+
from pregex import pregex
|
150 |
+
c = Counter(X)
|
151 |
+
regexes = [
|
152 |
+
pregex.create(".+"),
|
153 |
+
pregex.create("\d+"),
|
154 |
+
pregex.create("\w+"),
|
155 |
+
pregex.create("\s+"),
|
156 |
+
pregex.create("\\u+"),
|
157 |
+
pregex.create("\l+")]
|
158 |
+
regex_scores = []
|
159 |
+
for r in regexes:
|
160 |
+
regex_scores.append(sum(c[x] * r.match(x) for x in c)/float(sum([len(x) for x in X])) )
|
161 |
+
return max(regex_scores)
|
162 |
+
|
163 |
+
|
164 |
+
def make_corpus_unigram(C):
|
165 |
+
str_list = [example + '\n' for task in C for example in task]
|
166 |
+
c = Counter(char for example in str_list for char in example )
|
167 |
+
n = sum(c.values())
|
168 |
+
|
169 |
+
logp = {x:math.log(c[x]/n) for x in c}
|
170 |
+
return logp
|
171 |
+
|
172 |
+
def unigram_corpus_score(X, logp):
|
173 |
+
task_ll = 0
|
174 |
+
for x in X:
|
175 |
+
x = x + '\n'
|
176 |
+
task_ll += sum( logp.get(c, float('-inf')) for c in x)/len(x)
|
177 |
+
|
178 |
+
ll = task_ll/len(X)
|
179 |
+
return ll
|
180 |
+
|
181 |
+
def unigram_task_score(X):
|
182 |
+
"""
|
183 |
+
Given a list of strings, X, calculate the maximum log-likelihood per character for a unigram model over characters (including STOP symbol)
|
184 |
+
"""
|
185 |
+
c = Counter(x for s in X for x in s)
|
186 |
+
c.update("end" for s in X)
|
187 |
+
n = sum(c.values())
|
188 |
+
logp = {x:math.log(c[x]/n) for x in c}
|
189 |
+
return sum(c[x]*logp[x] for x in c)/n
|
190 |
+
|
191 |
+
def make_corpus_bigram(C):
|
192 |
+
#using newline as "end"
|
193 |
+
#C is a list of tasks
|
194 |
+
|
195 |
+
#make one big list of strings
|
196 |
+
str_list = [example + '\n' for task in C for example in task]
|
197 |
+
|
198 |
+
#make list of
|
199 |
+
head_count = Counter(element[0] for element in str_list)
|
200 |
+
head_n = sum(head_count.values())
|
201 |
+
head_logp = {x:math.log(head_count[x]/head_n) for x in head_count}
|
202 |
+
|
203 |
+
body_count = Counter(element[i:i+2] for element in str_list for i in range(len(element)-1))
|
204 |
+
body_bigram_n = sum(body_count.values())
|
205 |
+
#body_count/body_bigram_n gives the joint of a bigram
|
206 |
+
body_character_n = Counter(char for element in str_list for char in element)
|
207 |
+
body_unigram_n = sum(body_character_n.values())
|
208 |
+
|
209 |
+
body_logp = {x:math.log(body_count[x] / body_bigram_n / body_character_n[x[0]] * body_unigram_n) for x in body_count}
|
210 |
+
|
211 |
+
return {**head_logp, **body_logp}
|
212 |
+
|
213 |
+
def bigram_corpus_score(X, logp):
|
214 |
+
#assume you have a logp dict
|
215 |
+
task_ll = 0
|
216 |
+
for x in X:
|
217 |
+
bigram_list = [x[0]] + [x[i:i+2] for i in range(len(x)-1)] + [x[-1] + '\n']
|
218 |
+
bigram_list = [ ''.join(b) if isinstance(b,list) else b
|
219 |
+
for b in bigram_list ]
|
220 |
+
|
221 |
+
string_ll = sum(logp.get(bigram, float('-inf')) for bigram in bigram_list) #/(len(x) + 1)
|
222 |
+
|
223 |
+
task_ll += string_ll
|
224 |
+
|
225 |
+
ll = task_ll #/len(X)
|
226 |
+
return ll
|
227 |
+
|
228 |
+
|
229 |
+
class ProbabilisticLikelihoodModel:
|
230 |
+
|
231 |
+
def __init__(self, timeout):
|
232 |
+
self.timeout = timeout
|
233 |
+
# i need timeout
|
234 |
+
|
235 |
+
def score(self, program, task):
|
236 |
+
# need a try, catch here for problems, and for timeouts
|
237 |
+
# can copy task.py for the timeout structure
|
238 |
+
try:
|
239 |
+
def timeoutCallBack(_1, _2): raise EvaluationTimeout()
|
240 |
+
signal.signal(signal.SIGVTALRM, timeoutCallBack)
|
241 |
+
signal.setitimer(signal.ITIMER_VIRTUAL, self.timeout)
|
242 |
+
try:
|
243 |
+
string_pregex = program.evaluate([])
|
244 |
+
# if 'left_paren' in program.show(False):
|
245 |
+
#eprint("string_pregex:", string_pregex)
|
246 |
+
#eprint("string_pregex:", string_pregex)
|
247 |
+
preg = string_pregex # pregex.create(string_pregex)
|
248 |
+
except IndexError:
|
249 |
+
# free variable
|
250 |
+
return False, NEGATIVEINFINITY
|
251 |
+
except Exception as e:
|
252 |
+
eprint("Exception during evaluation:", e)
|
253 |
+
if "Attempt to evaluate fragment variable" in e:
|
254 |
+
eprint("program (bc fragment error)", program)
|
255 |
+
return False, NEGATIVEINFINITY
|
256 |
+
|
257 |
+
#tries and catches
|
258 |
+
|
259 |
+
# include prior somehow
|
260 |
+
# right now, just summing up log likelihoods. IDK if this is correct.
|
261 |
+
# also not using prior at all.
|
262 |
+
|
263 |
+
cum_ll = 0
|
264 |
+
|
265 |
+
example_list = [example[1] for example in task.examples]
|
266 |
+
c_example_list = Counter(example_list)
|
267 |
+
|
268 |
+
for c_example in c_example_list:
|
269 |
+
#might want a try, except around the following line:
|
270 |
+
|
271 |
+
try:
|
272 |
+
#eprint("about to match", program)
|
273 |
+
#print("preg:", preg)
|
274 |
+
ll = preg.match(c_example)
|
275 |
+
#eprint("completed match", ll, program)
|
276 |
+
except ValueError as e:
|
277 |
+
eprint("ValueError:", e)
|
278 |
+
ll = float('-inf')
|
279 |
+
|
280 |
+
#eprint("pregex:", string_pregex)
|
281 |
+
#eprint("example[1]", example[1])
|
282 |
+
|
283 |
+
if ll == float('-inf'):
|
284 |
+
return False, NEGATIVEINFINITY
|
285 |
+
else:
|
286 |
+
#ll_per_char = ll/float(len(example[1]))
|
287 |
+
#cum_ll_per_char += ll_per_char
|
288 |
+
|
289 |
+
cum_ll += c_example_list[c_example] * ll
|
290 |
+
|
291 |
+
#normalized_cum_ll_per_char = cum_ll_per_char/float(len(task.examples))
|
292 |
+
#avg_char_num = sum([len(example[1]) for example in task.examples])/float(len(task.examples))
|
293 |
+
|
294 |
+
#cutoff_ll = regex_plus_bound(example_list)
|
295 |
+
|
296 |
+
normalized_cum_ll = cum_ll/ float(sum([len(example) for example in example_list]))
|
297 |
+
|
298 |
+
|
299 |
+
|
300 |
+
#TODO: change the way normalized_cum_ll is calculated
|
301 |
+
#TODO: refactor to pass in bigram_model, and others
|
302 |
+
#TODO: refactor to do 95% certainty thing josh wants
|
303 |
+
success = normalized_cum_ll > task.ll_cutoff
|
304 |
+
|
305 |
+
|
306 |
+
|
307 |
+
#eprint("cutoff_ll:", cutoff_ll, ", norm_cum_ll:", normalized_cum_ll)
|
308 |
+
|
309 |
+
return success, normalized_cum_ll
|
310 |
+
|
311 |
+
except EvaluationTimeout:
|
312 |
+
eprint("Timed out while evaluating", program)
|
313 |
+
return False, NEGATIVEINFINITY
|
314 |
+
finally:
|
315 |
+
signal.signal(signal.SIGVTALRM, lambda *_: None)
|
316 |
+
signal.setitimer(signal.ITIMER_VIRTUAL, 0)
|
317 |
+
|
318 |
+
|
319 |
+
try:
|
320 |
+
import torch
|
321 |
+
import torch.nn as nn
|
322 |
+
import torch.nn.functional as F
|
323 |
+
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
|
324 |
+
from torch.autograd import Variable
|
325 |
+
|
326 |
+
class FeatureDiscriminatorLikelihoodModel(nn.Module):
|
327 |
+
def __init__(self, tasks, featureExtractor,
|
328 |
+
successCutoff=0.6, H=8, trainingSuccessRatio=0.5):
|
329 |
+
super(FeatureDiscriminatorLikelihoodModel, self).__init__()
|
330 |
+
self.extract = featureExtractor
|
331 |
+
self.successCutoff = successCutoff
|
332 |
+
self.trainingSuccessRatio = trainingSuccessRatio
|
333 |
+
|
334 |
+
self.W = nn.Linear(featureExtractor.outputDimensionality, H)
|
335 |
+
self.output = nn.Linear(H, 1)
|
336 |
+
|
337 |
+
# training on initialization
|
338 |
+
self.train(tasks)
|
339 |
+
|
340 |
+
def forward(self, examples):
|
341 |
+
"""
|
342 |
+
Examples is a list of feature sets corresponding to a particular example.
|
343 |
+
Output in [0,1] whether all examples correspond to the same program
|
344 |
+
"""
|
345 |
+
assert all(
|
346 |
+
len(x) == self.extract.outputDimensionality for x in examples)
|
347 |
+
examples = [F.tanh(self.W(ex)) for ex in examples]
|
348 |
+
maxed, _ = torch.max(torch.stack(examples), dim=0)
|
349 |
+
return F.sigmoid(self.output(maxed))
|
350 |
+
|
351 |
+
def train(self, tasks, steps=400):
|
352 |
+
# list of list of features for each example in each task
|
353 |
+
optimizer = torch.optim.Adam(self.parameters())
|
354 |
+
with timing("Trained discriminator"):
|
355 |
+
losses = []
|
356 |
+
for i in range(steps):
|
357 |
+
self.zero_grad()
|
358 |
+
if random.random() <= self.trainingSuccessRatio:
|
359 |
+
# success
|
360 |
+
t = random.choice(tasks)
|
361 |
+
features = [self.extract.featuresOfTask(
|
362 |
+
Task(t.name, t.request, [ex], t.features))
|
363 |
+
for ex in t.examples]
|
364 |
+
loss = (self(features) - 1.0)**2
|
365 |
+
else:
|
366 |
+
# fail
|
367 |
+
t1, t2 = random.sample(tasks, 2)
|
368 |
+
features1 = [self.extract.featuresOfTask(
|
369 |
+
Task(t1.name, t1.request, [ex], t1.features))
|
370 |
+
for ex in t1.examples[:len(t1.examples) / 2]]
|
371 |
+
features2 = [self.extract.featuresOfTask(
|
372 |
+
Task(t2.name, t2.request, [ex], t2.features))
|
373 |
+
for ex in t2.examples[len(t2.examples) / 2:]]
|
374 |
+
features = features1 + features2
|
375 |
+
loss = self(features)**2
|
376 |
+
|
377 |
+
loss.backward()
|
378 |
+
optimizer.step()
|
379 |
+
losses.append(loss.data[0])
|
380 |
+
if not i % 50:
|
381 |
+
eprint(
|
382 |
+
"Discriminator Epoch",
|
383 |
+
i,
|
384 |
+
"Loss",
|
385 |
+
sum(losses) /
|
386 |
+
len(losses))
|
387 |
+
gc.collect()
|
388 |
+
|
389 |
+
def score(self, program, task):
|
390 |
+
taskFeatures = self.extract.featuresOfTask(task)
|
391 |
+
progFeatures = self.extract.featuresOfProgram(
|
392 |
+
program, task.request)
|
393 |
+
likelihood = self([taskFeatures] + [progFeatures])
|
394 |
+
likelihood = float(likelihood)
|
395 |
+
return likelihood > self.successCutoff, log(likelihood)
|
396 |
+
except ImportError:
|
397 |
+
pass
|
398 |
+
|
399 |
+
|
400 |
+
if __name__=="__main__":
|
401 |
+
|
402 |
+
arr = ['MAM.OSBS.2014.06', 'MAM.OSBS.2013.07', 'MAM.OSBS.2013.09', 'MAM.OSBS.2014.05', 'MAM.OSBS.2014.11']
|
403 |
+
stems = longest_common_substr(arr)
|
404 |
+
print(stems)
|
405 |
+
|
406 |
+
|
407 |
+
|
dreamcoder/primitiveGraph.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dreamcoder.program import *
|
2 |
+
|
3 |
+
def graphPrimitives(result, prefix, view=False):
|
4 |
+
try:
|
5 |
+
from graphviz import Digraph
|
6 |
+
except:
|
7 |
+
eprint("You are missing the graphviz library - cannot graph primitives!")
|
8 |
+
return
|
9 |
+
|
10 |
+
|
11 |
+
primitives = { p
|
12 |
+
for g in result.grammars
|
13 |
+
for p in g.primitives
|
14 |
+
if p.isInvented }
|
15 |
+
age = {p: min(j for j,g in enumerate(result.grammars) if p in g.primitives) + 1
|
16 |
+
for p in primitives }
|
17 |
+
|
18 |
+
|
19 |
+
|
20 |
+
ages = set(age.values())
|
21 |
+
age2primitives = {a: {p for p,ap in age.items() if a == ap }
|
22 |
+
for a in ages}
|
23 |
+
|
24 |
+
def lb(s,T=20):
|
25 |
+
s = s.split()
|
26 |
+
l = []
|
27 |
+
n = 0
|
28 |
+
for w in s:
|
29 |
+
if n + len(w) > T:
|
30 |
+
l.append("<br />")
|
31 |
+
n = 0
|
32 |
+
n += len(w)
|
33 |
+
l.append(w)
|
34 |
+
return " ".join(l)
|
35 |
+
|
36 |
+
nameSimplification = {
|
37 |
+
"fix1": 'Y',
|
38 |
+
"tower_loopM": "for",
|
39 |
+
"tower_embed": "get/set",
|
40 |
+
"moveHand": "move",
|
41 |
+
"reverseHand": "reverse",
|
42 |
+
"logo_DIVA": '/',
|
43 |
+
"logo_epsA": 'ε',
|
44 |
+
"logo_epsL": 'ε',
|
45 |
+
"logo_IFTY": '∞',
|
46 |
+
"logo_forLoop": "for",
|
47 |
+
"logo_UA": "2π",
|
48 |
+
"logo_FWRT": "move",
|
49 |
+
"logo_UL": "1",
|
50 |
+
"logo_SUBA": "-",
|
51 |
+
"logo_ZL": "0",
|
52 |
+
"logo_ZA": "0",
|
53 |
+
"logo_MULL": "*",
|
54 |
+
"logo_MULA": "*",
|
55 |
+
"logo_PT": "pen-up",
|
56 |
+
"logo_GETSET": "get/set"
|
57 |
+
}
|
58 |
+
|
59 |
+
|
60 |
+
name = {}
|
61 |
+
simplification = {}
|
62 |
+
depth = {}
|
63 |
+
def getName(p):
|
64 |
+
if p in name: return name[p]
|
65 |
+
children = {k: getName(k)
|
66 |
+
for _,k in p.body.walk()
|
67 |
+
if k.isInvented}
|
68 |
+
simplification_ = p.body
|
69 |
+
for k,childName in children.items():
|
70 |
+
simplification_ = simplification_.substitute(k, Primitive(childName,None,None))
|
71 |
+
for original, simplified in nameSimplification.items():
|
72 |
+
simplification_ = simplification_.substitute(Primitive(original,None,None),
|
73 |
+
Primitive(simplified,None,None))
|
74 |
+
name[p] = "f%d"%len(name)
|
75 |
+
simplification[p] = name[p] + '=' + lb(prettyProgram(simplification_, Lisp=True))
|
76 |
+
depth[p] = 1 + max([depth[k] for k in children] + [0])
|
77 |
+
return name[p]
|
78 |
+
|
79 |
+
for p in primitives:
|
80 |
+
getName(p)
|
81 |
+
|
82 |
+
depths = {depth[p] for p in primitives}
|
83 |
+
depth2primitives = {d: {p for p in primitives if depth[p] == d }
|
84 |
+
for d in depths}
|
85 |
+
|
86 |
+
englishDescriptions = {"#(lambda (lambda (map (lambda (index $0 $2)) (range $0))))":
|
87 |
+
"Prefix",
|
88 |
+
"#(lambda (lambda (fold $0 $1 (lambda (lambda (cons $1 $0))))))":
|
89 |
+
"Append",
|
90 |
+
"#(lambda (cons LPAREN (#(lambda (lambda (fold $0 $1 (lambda (lambda (cons $1 $0)))))) (cons RPAREN empty) $0)))":
|
91 |
+
"Enclose w/ parens",
|
92 |
+
"#(lambda (unfold $0 (lambda (empty? $0)) (lambda (car $0)) (lambda (#(lambda (lambda (fold $1 $1 (lambda (lambda (cdr (if (char-eq? $1 $2) $3 $0))))))) $0 SPACE))))":
|
93 |
+
"Abbreviate",
|
94 |
+
"#(lambda (lambda (fold $1 $1 (lambda (lambda (cdr (if (char-eq? $1 $2) $3 $0)))))))":
|
95 |
+
"Drop until char",
|
96 |
+
"#(lambda (lambda (fold $1 $1 (lambda (lambda (if (char-eq? $1 $2) empty (cons $1 $0)))))))":
|
97 |
+
"Take until char",
|
98 |
+
"#(lambda (lambda (#(lambda (lambda (fold $0 $1 (lambda (lambda (cons $1 $0)))))) (cons $0 $1))))":
|
99 |
+
"Append char",
|
100 |
+
"#(lambda (lambda (map (lambda (if (char-eq? $0 $1) $2 $0)))))":
|
101 |
+
"Substitute char",
|
102 |
+
"#(lambda (lambda (length (unfold $1 (lambda (char-eq? (car $0) $1)) (lambda ',') (lambda (cdr $0))))))":
|
103 |
+
"Index of char",
|
104 |
+
"#(lambda (#(lambda (lambda (fold $0 $1 (lambda (lambda (cons $1 $0)))))) $0 STRING))":
|
105 |
+
"Append const",
|
106 |
+
"#(lambda (lambda (fold $1 $1 (lambda (lambda (fold $0 $0 (lambda (lambda (cdr (if (char-eq? $1 $4) $0 (cons $1 $0)))))))))))":
|
107 |
+
"Last word",
|
108 |
+
"#(lambda (lambda (cons (car $1) (cons '.' (cons (car $0) (cons '.' empty))))))":
|
109 |
+
"Abbreviate name",
|
110 |
+
"#(lambda (lambda (cons (car $1) (cons $0 empty))))":
|
111 |
+
"First char+char",
|
112 |
+
"#(lambda (#(lambda (lambda (fold $0 $1 (lambda (lambda (cons $1 $0)))))) (#(lambda (lambda (fold $1 $1 (lambda (lambda (fold $0 $0 (lambda (lambda (cdr (if (char-eq? $1 $4) $0 (cons $1 $0))))))))))) STRING (index (length (cdr $0)) $0)) $0))":
|
113 |
+
"Ensure suffix"
|
114 |
+
|
115 |
+
}
|
116 |
+
|
117 |
+
def makeUnorderedGraph(fn):
|
118 |
+
g = Digraph()
|
119 |
+
g.graph_attr['rankdir'] = 'LR'
|
120 |
+
|
121 |
+
for p in primitives:
|
122 |
+
g.node(getName(p),
|
123 |
+
label="<%s>"%simplification[p])
|
124 |
+
for p in primitives:
|
125 |
+
children = {k
|
126 |
+
for _,k in p.body.walk()
|
127 |
+
if k.isInvented}
|
128 |
+
for k in children:
|
129 |
+
g.edge(name[k],name[p])
|
130 |
+
try:
|
131 |
+
g.render(fn,view=view)
|
132 |
+
eprint("Exported primitive graph to",fn)
|
133 |
+
except:
|
134 |
+
eprint("Got some kind of error while trying to render primitive graph! Did you install graphviz/dot?")
|
135 |
+
|
136 |
+
|
137 |
+
|
138 |
+
def makeGraph(ordering, fn):
|
139 |
+
g = Digraph()
|
140 |
+
g.graph_attr['rankdir'] = 'RL'
|
141 |
+
|
142 |
+
if False:
|
143 |
+
with g.subgraph(name='cluster_0') as sg:
|
144 |
+
sg.graph_attr['rank'] = 'same'
|
145 |
+
sg.attr(label='Primitives')
|
146 |
+
for j, primitive in enumerate(result.grammars[-1].primitives):
|
147 |
+
if primitive.isInvented: continue
|
148 |
+
sg.node("primitive%d"%j, label=str(primitive))
|
149 |
+
|
150 |
+
for o in sorted(ordering.keys()):
|
151 |
+
with g.subgraph(name='cluster_%d'%o) as sg:
|
152 |
+
sg.graph_attr['rank'] = 'same'
|
153 |
+
#sg.attr(label='Depth %d'%o)
|
154 |
+
for p in ordering[o]:
|
155 |
+
if str(p) in englishDescriptions:
|
156 |
+
thisLabel = '<<font face="boldfontname"><u>%s</u></font><br />%s>'%(englishDescriptions[str(p)],simplification[p])
|
157 |
+
else:
|
158 |
+
eprint("WARNING: Do not have an English description of:\n",p)
|
159 |
+
eprint()
|
160 |
+
thisLabel = "<%s>"%simplification[p]
|
161 |
+
sg.node(getName(p),
|
162 |
+
label=thisLabel)
|
163 |
+
|
164 |
+
for p in ordering[o]:
|
165 |
+
children = {k
|
166 |
+
for _,k in p.body.walk()
|
167 |
+
if k.isInvented}
|
168 |
+
for k in children:
|
169 |
+
g.edge(name[k],name[p])
|
170 |
+
|
171 |
+
eprint("Exporting primitive graph to",fn)
|
172 |
+
try:
|
173 |
+
g.render(fn,view=view)
|
174 |
+
except Exception as e:
|
175 |
+
eprint("Got some kind of error while trying to render primitive graph! Did you install graphviz/dot?")
|
176 |
+
print(e)
|
177 |
+
|
178 |
+
|
179 |
+
|
180 |
+
makeGraph(depth2primitives,prefix+'depth.pdf')
|
181 |
+
makeUnorderedGraph(prefix+'unordered.pdf')
|
182 |
+
#makeGraph(age2primitives,prefix+'iter.pdf')
|
dreamcoder/program.py
ADDED
@@ -0,0 +1,1214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
from dreamcoder.type import *
|
4 |
+
from dreamcoder.utilities import *
|
5 |
+
|
6 |
+
from time import time
|
7 |
+
import math
|
8 |
+
|
9 |
+
|
10 |
+
class InferenceFailure(Exception):
|
11 |
+
pass
|
12 |
+
|
13 |
+
|
14 |
+
class ShiftFailure(Exception):
|
15 |
+
pass
|
16 |
+
|
17 |
+
class RunFailure(Exception):
|
18 |
+
pass
|
19 |
+
|
20 |
+
|
21 |
+
class Program(object):
|
22 |
+
def __repr__(self): return str(self)
|
23 |
+
|
24 |
+
def __ne__(self, o): return not (self == o)
|
25 |
+
|
26 |
+
def __str__(self): return self.show(False)
|
27 |
+
|
28 |
+
def canHaveType(self, t):
|
29 |
+
try:
|
30 |
+
context, actualType = self.inferType(Context.EMPTY, [], {})
|
31 |
+
context, t = t.instantiate(context)
|
32 |
+
context.unify(t, actualType)
|
33 |
+
return True
|
34 |
+
except UnificationFailure as e:
|
35 |
+
return False
|
36 |
+
|
37 |
+
def betaNormalForm(self):
|
38 |
+
n = self
|
39 |
+
while True:
|
40 |
+
np = n.betaReduce()
|
41 |
+
if np is None: return n
|
42 |
+
n = np
|
43 |
+
|
44 |
+
def infer(self):
|
45 |
+
try:
|
46 |
+
return self.inferType(Context.EMPTY, [], {})[1].canonical()
|
47 |
+
except UnificationFailure as e:
|
48 |
+
raise InferenceFailure(self, e)
|
49 |
+
|
50 |
+
def uncurry(self):
|
51 |
+
t = self.infer()
|
52 |
+
a = len(t.functionArguments())
|
53 |
+
e = self
|
54 |
+
existingAbstractions = 0
|
55 |
+
while e.isAbstraction:
|
56 |
+
e = e.body
|
57 |
+
existingAbstractions += 1
|
58 |
+
newAbstractions = a - existingAbstractions
|
59 |
+
assert newAbstractions >= 0
|
60 |
+
|
61 |
+
# e is the body stripped of abstractions. we are going to pile
|
62 |
+
# some more lambdas at the front, so free variables in e
|
63 |
+
# (which were bound to the stripped abstractions) need to be
|
64 |
+
# shifted by the number of abstractions that we will be adding
|
65 |
+
e = e.shift(newAbstractions)
|
66 |
+
|
67 |
+
for n in reversed(range(newAbstractions)):
|
68 |
+
e = Application(e, Index(n))
|
69 |
+
for _ in range(a):
|
70 |
+
e = Abstraction(e)
|
71 |
+
|
72 |
+
assert self.infer() == e.infer(), \
|
73 |
+
"FATAL: uncurry has a bug. %s : %s, but uncurried to %s : %s" % (self, self.infer(),
|
74 |
+
e, e.infer())
|
75 |
+
return e
|
76 |
+
|
77 |
+
def wellTyped(self):
|
78 |
+
try:
|
79 |
+
self.infer()
|
80 |
+
return True
|
81 |
+
except InferenceFailure:
|
82 |
+
return False
|
83 |
+
|
84 |
+
def runWithArguments(self, xs):
|
85 |
+
f = self.evaluate([])
|
86 |
+
for x in xs:
|
87 |
+
f = f(x)
|
88 |
+
return f
|
89 |
+
|
90 |
+
def applicationParses(self): yield self, []
|
91 |
+
|
92 |
+
def applicationParse(self): return self, []
|
93 |
+
|
94 |
+
@property
|
95 |
+
def closed(self):
|
96 |
+
for surroundingAbstractions, child in self.walk():
|
97 |
+
if isinstance(child, FragmentVariable):
|
98 |
+
return False
|
99 |
+
if isinstance(child, Index) and child.free(
|
100 |
+
surroundingAbstractions):
|
101 |
+
return False
|
102 |
+
return True
|
103 |
+
|
104 |
+
@property
|
105 |
+
def numberOfFreeVariables(expression):
|
106 |
+
n = 0
|
107 |
+
for surroundingAbstractions, child in expression.walk():
|
108 |
+
# Free variable
|
109 |
+
if isinstance(child, Index) and child.free(
|
110 |
+
surroundingAbstractions):
|
111 |
+
n = max(n, child.i - surroundingAbstractions + 1)
|
112 |
+
return n
|
113 |
+
|
114 |
+
def freeVariables(self):
|
115 |
+
for surroundingAbstractions, child in self.walk():
|
116 |
+
if child.isIndex and child.i >= surroundingAbstractions:
|
117 |
+
yield child.i - surroundingAbstractions
|
118 |
+
|
119 |
+
@property
|
120 |
+
def isIndex(self): return False
|
121 |
+
|
122 |
+
@property
|
123 |
+
def isUnion(self): return False
|
124 |
+
|
125 |
+
@property
|
126 |
+
def isApplication(self): return False
|
127 |
+
|
128 |
+
@property
|
129 |
+
def isAbstraction(self): return False
|
130 |
+
|
131 |
+
@property
|
132 |
+
def isPrimitive(self): return False
|
133 |
+
|
134 |
+
@property
|
135 |
+
def isInvented(self): return False
|
136 |
+
|
137 |
+
@property
|
138 |
+
def isHole(self): return False
|
139 |
+
|
140 |
+
@staticmethod
|
141 |
+
def parse(s):
|
142 |
+
s = parseSExpression(s)
|
143 |
+
def p(e):
|
144 |
+
if isinstance(e,list):
|
145 |
+
if e[0] == '#':
|
146 |
+
assert len(e) == 2
|
147 |
+
return Invented(p(e[1]))
|
148 |
+
if e[0] == 'lambda':
|
149 |
+
assert len(e) == 2
|
150 |
+
return Abstraction(p(e[1]))
|
151 |
+
f = p(e[0])
|
152 |
+
for x in e[1:]:
|
153 |
+
f = Application(f,p(x))
|
154 |
+
return f
|
155 |
+
assert isinstance(e,str)
|
156 |
+
if e[0] == '$': return Index(int(e[1:]))
|
157 |
+
if e in Primitive.GLOBALS: return Primitive.GLOBALS[e]
|
158 |
+
if e == '??' or e == '?': return FragmentVariable.single
|
159 |
+
if e == '<HOLE>': return Hole.single
|
160 |
+
raise ParseFailure((s,e))
|
161 |
+
return p(s)
|
162 |
+
|
163 |
+
@staticmethod
|
164 |
+
def _parse(s,n):
|
165 |
+
while n < len(s) and s[n].isspace():
|
166 |
+
n += 1
|
167 |
+
for p in [
|
168 |
+
Application,
|
169 |
+
Abstraction,
|
170 |
+
Index,
|
171 |
+
Invented,
|
172 |
+
FragmentVariable,
|
173 |
+
Hole,
|
174 |
+
Primitive]:
|
175 |
+
try:
|
176 |
+
return p._parse(s,n)
|
177 |
+
except ParseFailure:
|
178 |
+
continue
|
179 |
+
raise ParseFailure(s)
|
180 |
+
|
181 |
+
# parser helpers
|
182 |
+
@staticmethod
|
183 |
+
def parseConstant(s,n,*constants):
|
184 |
+
for constant in constants:
|
185 |
+
try:
|
186 |
+
for i,c in enumerate(constant):
|
187 |
+
if i + n >= len(s) or s[i + n] != c: raise ParseFailure(s)
|
188 |
+
return n + len(constant)
|
189 |
+
except ParseFailure: continue
|
190 |
+
raise ParseFailure(s)
|
191 |
+
|
192 |
+
@staticmethod
|
193 |
+
def parseHumanReadable(s):
|
194 |
+
s = parseSExpression(s)
|
195 |
+
def p(s, environment):
|
196 |
+
if isinstance(s, list) and s[0] in ['lambda','\\']:
|
197 |
+
assert isinstance(s[1], list) and len(s) == 3
|
198 |
+
newEnvironment = list(reversed(s[1])) + environment
|
199 |
+
e = p(s[2], newEnvironment)
|
200 |
+
for _ in s[1]: e = Abstraction(e)
|
201 |
+
return e
|
202 |
+
if isinstance(s, list):
|
203 |
+
a = p(s[0], environment)
|
204 |
+
for x in s[1:]:
|
205 |
+
a = Application(a, p(x, environment))
|
206 |
+
return a
|
207 |
+
for j,v in enumerate(environment):
|
208 |
+
if s == v: return Index(j)
|
209 |
+
if s in Primitive.GLOBALS: return Primitive.GLOBALS[s]
|
210 |
+
assert False, f"could not parse {s}"
|
211 |
+
return p(s, [])
|
212 |
+
|
213 |
+
|
214 |
+
|
215 |
+
|
216 |
+
class Application(Program):
|
217 |
+
'''Function application'''
|
218 |
+
|
219 |
+
def __init__(self, f, x):
|
220 |
+
self.f = f
|
221 |
+
self.x = x
|
222 |
+
self.hashCode = None
|
223 |
+
self.isConditional = (not isinstance(f,int)) and \
|
224 |
+
f.isApplication and \
|
225 |
+
f.f.isApplication and \
|
226 |
+
f.f.f.isPrimitive and \
|
227 |
+
f.f.f.name == "if"
|
228 |
+
if self.isConditional:
|
229 |
+
self.falseBranch = x
|
230 |
+
self.trueBranch = f.x
|
231 |
+
self.branch = f.f.x
|
232 |
+
else:
|
233 |
+
self.falseBranch = None
|
234 |
+
self.trueBranch = None
|
235 |
+
self.branch = None
|
236 |
+
|
237 |
+
def betaReduce(self):
|
238 |
+
# See if either the function or the argument can be reduced
|
239 |
+
f = self.f.betaReduce()
|
240 |
+
if f is not None: return Application(f,self.x)
|
241 |
+
x = self.x.betaReduce()
|
242 |
+
if x is not None: return Application(self.f,x)
|
243 |
+
|
244 |
+
# Neither of them could be reduced. Is this not a redex?
|
245 |
+
if not self.f.isAbstraction: return None
|
246 |
+
|
247 |
+
# Perform substitution
|
248 |
+
b = self.f.body
|
249 |
+
v = self.x
|
250 |
+
return b.substitute(Index(0), v.shift(1)).shift(-1)
|
251 |
+
|
252 |
+
def isBetaLong(self):
|
253 |
+
return (not self.f.isAbstraction) and self.f.isBetaLong() and self.x.isBetaLong()
|
254 |
+
|
255 |
+
def freeVariables(self):
|
256 |
+
return self.f.freeVariables() | self.x.freeVariables()
|
257 |
+
|
258 |
+
def clone(self): return Application(self.f.clone(), self.x.clone())
|
259 |
+
|
260 |
+
def annotateTypes(self, context, environment):
|
261 |
+
self.f.annotateTypes(context, environment)
|
262 |
+
self.x.annotateTypes(context, environment)
|
263 |
+
r = context.makeVariable()
|
264 |
+
context.unify(arrow(self.x.annotatedType, r), self.f.annotatedType)
|
265 |
+
self.annotatedType = r.applyMutable(context)
|
266 |
+
|
267 |
+
|
268 |
+
@property
|
269 |
+
def isApplication(self): return True
|
270 |
+
|
271 |
+
def __eq__(
|
272 |
+
self,
|
273 |
+
other): return isinstance(
|
274 |
+
other,
|
275 |
+
Application) and self.f == other.f and self.x == other.x
|
276 |
+
|
277 |
+
def __hash__(self):
|
278 |
+
if self.hashCode is None:
|
279 |
+
self.hashCode = hash((hash(self.f), hash(self.x)))
|
280 |
+
return self.hashCode
|
281 |
+
|
282 |
+
"""Because Python3 randomizes the hash function, we need to never pickle the hash"""
|
283 |
+
def __getstate__(self):
|
284 |
+
return self.f, self.x, self.isConditional, self.falseBranch, self.trueBranch, self.branch
|
285 |
+
def __setstate__(self, state):
|
286 |
+
try:
|
287 |
+
self.f, self.x, self.isConditional, self.falseBranch, self.trueBranch, self.branch = state
|
288 |
+
except ValueError:
|
289 |
+
# backward compatibility
|
290 |
+
assert 'x' in state
|
291 |
+
assert 'f' in state
|
292 |
+
f = state['f']
|
293 |
+
x = state['x']
|
294 |
+
self.f = f
|
295 |
+
self.x = x
|
296 |
+
self.isConditional = (not isinstance(f,int)) and \
|
297 |
+
f.isApplication and \
|
298 |
+
f.f.isApplication and \
|
299 |
+
f.f.f.isPrimitive and \
|
300 |
+
f.f.f.name == "if"
|
301 |
+
if self.isConditional:
|
302 |
+
self.falseBranch = x
|
303 |
+
self.trueBranch = f.x
|
304 |
+
self.branch = f.f.x
|
305 |
+
else:
|
306 |
+
self.falseBranch = None
|
307 |
+
self.trueBranch = None
|
308 |
+
self.branch = None
|
309 |
+
|
310 |
+
self.hashCode = None
|
311 |
+
|
312 |
+
def visit(self,
|
313 |
+
visitor,
|
314 |
+
*arguments,
|
315 |
+
**keywords): return visitor.application(self,
|
316 |
+
*arguments,
|
317 |
+
**keywords)
|
318 |
+
|
319 |
+
def show(self, isFunction):
|
320 |
+
if isFunction:
|
321 |
+
return "%s %s" % (self.f.show(True), self.x.show(False))
|
322 |
+
else:
|
323 |
+
return "(%s %s)" % (self.f.show(True), self.x.show(False))
|
324 |
+
|
325 |
+
def evaluate(self, environment):
|
326 |
+
if self.isConditional:
|
327 |
+
if self.branch.evaluate(environment):
|
328 |
+
return self.trueBranch.evaluate(environment)
|
329 |
+
else:
|
330 |
+
return self.falseBranch.evaluate(environment)
|
331 |
+
else:
|
332 |
+
return self.f.evaluate(environment)(self.x.evaluate(environment))
|
333 |
+
|
334 |
+
def inferType(self, context, environment, freeVariables):
|
335 |
+
(context, ft) = self.f.inferType(context, environment, freeVariables)
|
336 |
+
(context, xt) = self.x.inferType(context, environment, freeVariables)
|
337 |
+
(context, returnType) = context.makeVariable()
|
338 |
+
context = context.unify(ft, arrow(xt, returnType))
|
339 |
+
return (context, returnType.apply(context))
|
340 |
+
|
341 |
+
def applicationParses(self):
|
342 |
+
yield self, []
|
343 |
+
for f, xs in self.f.applicationParses():
|
344 |
+
yield f, xs + [self.x]
|
345 |
+
|
346 |
+
def applicationParse(self):
|
347 |
+
f, xs = self.f.applicationParse()
|
348 |
+
return f, xs + [self.x]
|
349 |
+
|
350 |
+
def shift(self, offset, depth=0):
|
351 |
+
return Application(self.f.shift(offset, depth),
|
352 |
+
self.x.shift(offset, depth))
|
353 |
+
|
354 |
+
def substitute(self, old, new):
|
355 |
+
if self == old:
|
356 |
+
return new
|
357 |
+
return Application(
|
358 |
+
self.f.substitute(
|
359 |
+
old, new), self.x.substitute(
|
360 |
+
old, new))
|
361 |
+
|
362 |
+
def walkUncurried(self, d=0):
|
363 |
+
yield d, self
|
364 |
+
f, xs = self.applicationParse()
|
365 |
+
yield from f.walkUncurried(d)
|
366 |
+
for x in xs:
|
367 |
+
yield from x.walkUncurried(d)
|
368 |
+
|
369 |
+
def walk(self, surroundingAbstractions=0):
|
370 |
+
yield surroundingAbstractions, self
|
371 |
+
yield from self.f.walk(surroundingAbstractions)
|
372 |
+
yield from self.x.walk(surroundingAbstractions)
|
373 |
+
|
374 |
+
def size(self): return self.f.size() + self.x.size()
|
375 |
+
|
376 |
+
@staticmethod
|
377 |
+
def _parse(s,n):
|
378 |
+
while n < len(s) and s[n].isspace(): n += 1
|
379 |
+
if n == len(s) or s[n] != '(': raise ParseFailure(s)
|
380 |
+
n += 1
|
381 |
+
|
382 |
+
xs = []
|
383 |
+
while True:
|
384 |
+
x, n = Program._parse(s, n)
|
385 |
+
xs.append(x)
|
386 |
+
while n < len(s) and s[n].isspace(): n += 1
|
387 |
+
if n == len(s):
|
388 |
+
raise ParseFailure(s)
|
389 |
+
if s[n] == ")":
|
390 |
+
n += 1
|
391 |
+
break
|
392 |
+
e = xs[0]
|
393 |
+
for x in xs[1:]:
|
394 |
+
e = Application(e, x)
|
395 |
+
return e, n
|
396 |
+
|
397 |
+
|
398 |
+
class Index(Program):
|
399 |
+
'''
|
400 |
+
deBruijn index: https://en.wikipedia.org/wiki/De_Bruijn_index
|
401 |
+
These indices encode variables.
|
402 |
+
'''
|
403 |
+
|
404 |
+
def __init__(self, i):
|
405 |
+
self.i = i
|
406 |
+
|
407 |
+
def show(self, isFunction): return "$%d" % self.i
|
408 |
+
|
409 |
+
def __eq__(self, o): return isinstance(o, Index) and o.i == self.i
|
410 |
+
|
411 |
+
def __hash__(self): return self.i
|
412 |
+
|
413 |
+
def visit(self,
|
414 |
+
visitor,
|
415 |
+
*arguments,
|
416 |
+
**keywords): return visitor.index(self,
|
417 |
+
*arguments,
|
418 |
+
**keywords)
|
419 |
+
|
420 |
+
def evaluate(self, environment):
|
421 |
+
return environment[self.i]
|
422 |
+
|
423 |
+
def inferType(self, context, environment, freeVariables):
|
424 |
+
if self.bound(len(environment)):
|
425 |
+
return (context, environment[self.i].apply(context))
|
426 |
+
else:
|
427 |
+
i = self.i - len(environment)
|
428 |
+
if i in freeVariables:
|
429 |
+
return (context, freeVariables[i].apply(context))
|
430 |
+
context, variable = context.makeVariable()
|
431 |
+
freeVariables[i] = variable
|
432 |
+
return (context, variable)
|
433 |
+
|
434 |
+
def clone(self): return Index(self.i)
|
435 |
+
|
436 |
+
def annotateTypes(self, context, environment):
|
437 |
+
self.annotatedType = environment[self.i].applyMutable(context)
|
438 |
+
|
439 |
+
def shift(self, offset, depth=0):
|
440 |
+
# bound variable
|
441 |
+
if self.bound(depth):
|
442 |
+
return self
|
443 |
+
else: # free variable
|
444 |
+
i = self.i + offset
|
445 |
+
if i < 0:
|
446 |
+
raise ShiftFailure()
|
447 |
+
return Index(i)
|
448 |
+
|
449 |
+
def betaReduce(self): return None
|
450 |
+
|
451 |
+
def isBetaLong(self): return True
|
452 |
+
|
453 |
+
def freeVariables(self): return {self.i}
|
454 |
+
|
455 |
+
def substitute(self, old, new):
|
456 |
+
if old == self:
|
457 |
+
return new
|
458 |
+
else:
|
459 |
+
return self
|
460 |
+
|
461 |
+
def walk(self, surroundingAbstractions=0): yield surroundingAbstractions, self
|
462 |
+
|
463 |
+
def walkUncurried(self, d=0): yield d, self
|
464 |
+
|
465 |
+
def size(self): return 1
|
466 |
+
|
467 |
+
def free(self, surroundingAbstractions):
|
468 |
+
'''Is this index a free variable, given that it has surroundingAbstractions lambda's around it?'''
|
469 |
+
return self.i >= surroundingAbstractions
|
470 |
+
|
471 |
+
def bound(self, surroundingAbstractions):
|
472 |
+
'''Is this index a bound variable, given that it has surroundingAbstractions lambda's around it?'''
|
473 |
+
return self.i < surroundingAbstractions
|
474 |
+
|
475 |
+
@property
|
476 |
+
def isIndex(self): return True
|
477 |
+
|
478 |
+
@staticmethod
|
479 |
+
def _parse(s,n):
|
480 |
+
while n < len(s) and s[n].isspace(): n += 1
|
481 |
+
if n == len(s) or s[n] != '$':
|
482 |
+
raise ParseFailure(s)
|
483 |
+
n += 1
|
484 |
+
j = ""
|
485 |
+
while n < len(s) and s[n].isdigit():
|
486 |
+
j += s[n]
|
487 |
+
n += 1
|
488 |
+
if j == "":
|
489 |
+
raise ParseFailure(s)
|
490 |
+
return Index(int(j)), n
|
491 |
+
|
492 |
+
|
493 |
+
class Abstraction(Program):
|
494 |
+
'''Lambda abstraction. Creates a new function.'''
|
495 |
+
|
496 |
+
def __init__(self, body):
|
497 |
+
self.body = body
|
498 |
+
self.hashCode = None
|
499 |
+
|
500 |
+
@property
|
501 |
+
def isAbstraction(self): return True
|
502 |
+
|
503 |
+
def __eq__(self, o): return isinstance(
|
504 |
+
o, Abstraction) and o.body == self.body
|
505 |
+
|
506 |
+
def __hash__(self):
|
507 |
+
if self.hashCode is None:
|
508 |
+
self.hashCode = hash((hash(self.body),))
|
509 |
+
return self.hashCode
|
510 |
+
|
511 |
+
"""Because Python3 randomizes the hash function, we need to never pickle the hash"""
|
512 |
+
def __getstate__(self):
|
513 |
+
return self.body
|
514 |
+
def __setstate__(self, state):
|
515 |
+
self.body = state
|
516 |
+
self.hashCode = None
|
517 |
+
|
518 |
+
def isBetaLong(self): return self.body.isBetaLong()
|
519 |
+
|
520 |
+
def freeVariables(self):
|
521 |
+
return {f - 1 for f in self.body.freeVariables() if f > 0}
|
522 |
+
|
523 |
+
def visit(self,
|
524 |
+
visitor,
|
525 |
+
*arguments,
|
526 |
+
**keywords): return visitor.abstraction(self,
|
527 |
+
*arguments,
|
528 |
+
**keywords)
|
529 |
+
|
530 |
+
def clone(self): return Abstraction(self.body.clone())
|
531 |
+
|
532 |
+
def annotateTypes(self, context, environment):
|
533 |
+
v = context.makeVariable()
|
534 |
+
self.body.annotateTypes(context, [v] + environment)
|
535 |
+
self.annotatedType = arrow(v.applyMutable(context), self.body.annotatedType)
|
536 |
+
|
537 |
+
def show(self, isFunction):
|
538 |
+
return "(lambda %s)" % (self.body.show(False))
|
539 |
+
|
540 |
+
def evaluate(self, environment):
|
541 |
+
return lambda x: self.body.evaluate([x] + environment)
|
542 |
+
|
543 |
+
def betaReduce(self):
|
544 |
+
b = self.body.betaReduce()
|
545 |
+
if b is None: return None
|
546 |
+
return Abstraction(b)
|
547 |
+
|
548 |
+
def inferType(self, context, environment, freeVariables):
|
549 |
+
(context, argumentType) = context.makeVariable()
|
550 |
+
(context, returnType) = self.body.inferType(
|
551 |
+
context, [argumentType] + environment, freeVariables)
|
552 |
+
return (context, arrow(argumentType, returnType).apply(context))
|
553 |
+
|
554 |
+
def shift(self, offset, depth=0):
|
555 |
+
return Abstraction(self.body.shift(offset, depth + 1))
|
556 |
+
|
557 |
+
def substitute(self, old, new):
|
558 |
+
if self == old:
|
559 |
+
return new
|
560 |
+
old = old.shift(1)
|
561 |
+
new = new.shift(1)
|
562 |
+
return Abstraction(self.body.substitute(old, new))
|
563 |
+
|
564 |
+
def walk(self, surroundingAbstractions=0):
|
565 |
+
yield surroundingAbstractions, self
|
566 |
+
yield from self.body.walk(surroundingAbstractions + 1)
|
567 |
+
|
568 |
+
def walkUncurried(self, d=0):
|
569 |
+
yield d, self
|
570 |
+
yield from self.body.walkUncurried(d + 1)
|
571 |
+
|
572 |
+
def size(self): return self.body.size()
|
573 |
+
|
574 |
+
@staticmethod
|
575 |
+
def _parse(s,n):
|
576 |
+
n = Program.parseConstant(s,n,
|
577 |
+
'(\\','(lambda','(\u03bb')
|
578 |
+
|
579 |
+
while n < len(s) and s[n].isspace(): n += 1
|
580 |
+
|
581 |
+
b, n = Program._parse(s,n)
|
582 |
+
while n < len(s) and s[n].isspace(): n += 1
|
583 |
+
n = Program.parseConstant(s,n,')')
|
584 |
+
return Abstraction(b), n
|
585 |
+
|
586 |
+
|
587 |
+
class Primitive(Program):
|
588 |
+
GLOBALS = {}
|
589 |
+
|
590 |
+
def __init__(self, name, ty, value):
|
591 |
+
self.tp = ty
|
592 |
+
self.name = name
|
593 |
+
self.value = value
|
594 |
+
if name not in Primitive.GLOBALS:
|
595 |
+
Primitive.GLOBALS[name] = self
|
596 |
+
|
597 |
+
@property
|
598 |
+
def isPrimitive(self): return True
|
599 |
+
|
600 |
+
def __eq__(self, o): return isinstance(
|
601 |
+
o, Primitive) and o.name == self.name
|
602 |
+
|
603 |
+
def __hash__(self): return hash(self.name)
|
604 |
+
|
605 |
+
def visit(self,
|
606 |
+
visitor,
|
607 |
+
*arguments,
|
608 |
+
**keywords): return visitor.primitive(self,
|
609 |
+
*arguments,
|
610 |
+
**keywords)
|
611 |
+
|
612 |
+
def show(self, isFunction): return self.name
|
613 |
+
|
614 |
+
def clone(self): return Primitive(self.name, self.tp, self.value)
|
615 |
+
|
616 |
+
def annotateTypes(self, context, environment):
|
617 |
+
self.annotatedType = self.tp.instantiateMutable(context)
|
618 |
+
|
619 |
+
def evaluate(self, environment): return self.value
|
620 |
+
|
621 |
+
def betaReduce(self): return None
|
622 |
+
|
623 |
+
def isBetaLong(self): return True
|
624 |
+
|
625 |
+
def freeVariables(self): return set()
|
626 |
+
|
627 |
+
def inferType(self, context, environment, freeVariables):
|
628 |
+
return self.tp.instantiate(context)
|
629 |
+
|
630 |
+
def shift(self, offset, depth=0): return self
|
631 |
+
|
632 |
+
def substitute(self, old, new):
|
633 |
+
if self == old:
|
634 |
+
return new
|
635 |
+
else:
|
636 |
+
return self
|
637 |
+
|
638 |
+
def walk(self, surroundingAbstractions=0): yield surroundingAbstractions, self
|
639 |
+
|
640 |
+
def walkUncurried(self, d=0): yield d, self
|
641 |
+
|
642 |
+
def size(self): return 1
|
643 |
+
|
644 |
+
@staticmethod
|
645 |
+
def _parse(s,n):
|
646 |
+
while n < len(s) and s[n].isspace(): n += 1
|
647 |
+
name = []
|
648 |
+
while n < len(s) and not s[n].isspace() and s[n] not in '()':
|
649 |
+
name.append(s[n])
|
650 |
+
n += 1
|
651 |
+
name = "".join(name)
|
652 |
+
if name in Primitive.GLOBALS:
|
653 |
+
return Primitive.GLOBALS[name], n
|
654 |
+
raise ParseFailure(s)
|
655 |
+
|
656 |
+
# TODO(@mtensor): needs to be fixed to handle both pickling lambda functions and unpickling in general.
|
657 |
+
# def __getstate__(self):
|
658 |
+
# return self.name
|
659 |
+
|
660 |
+
# def __setstate__(self, state):
|
661 |
+
# #for backwards compatibility:
|
662 |
+
# if type(state) == dict:
|
663 |
+
# self.__dict__ = state
|
664 |
+
# else:
|
665 |
+
# p = Primitive.GLOBALS[state]
|
666 |
+
# self.__init__(p.name, p.tp, p.value)
|
667 |
+
|
668 |
+
class Invented(Program):
|
669 |
+
'''New invented primitives'''
|
670 |
+
|
671 |
+
def __init__(self, body):
|
672 |
+
self.body = body
|
673 |
+
self.tp = self.body.infer()
|
674 |
+
self.hashCode = None
|
675 |
+
|
676 |
+
@property
|
677 |
+
def isInvented(self): return True
|
678 |
+
|
679 |
+
def show(self, isFunction): return "#%s" % (self.body.show(False))
|
680 |
+
|
681 |
+
def visit(self,
|
682 |
+
visitor,
|
683 |
+
*arguments,
|
684 |
+
**keywords): return visitor.invented(self,
|
685 |
+
*arguments,
|
686 |
+
**keywords)
|
687 |
+
|
688 |
+
def __eq__(self, o): return isinstance(o, Invented) and o.body == self.body
|
689 |
+
|
690 |
+
def __hash__(self):
|
691 |
+
if self.hashCode is None:
|
692 |
+
self.hashCode = hash((0, hash(self.body)))
|
693 |
+
return self.hashCode
|
694 |
+
|
695 |
+
"""Because Python3 randomizes the hash function, we need to never pickle the hash"""
|
696 |
+
def __getstate__(self):
|
697 |
+
return self.body, self.tp
|
698 |
+
def __setstate__(self, state):
|
699 |
+
self.body, self.tp = state
|
700 |
+
self.hashCode = None
|
701 |
+
|
702 |
+
def clone(self): return Invented(self.body)
|
703 |
+
|
704 |
+
def annotateTypes(self, context, environment):
|
705 |
+
self.annotatedType = self.tp.instantiateMutable(context)
|
706 |
+
|
707 |
+
def evaluate(self, e): return self.body.evaluate([])
|
708 |
+
|
709 |
+
def betaReduce(self): return self.body
|
710 |
+
|
711 |
+
def isBetaLong(self): return True
|
712 |
+
|
713 |
+
def freeVariables(self): return set()
|
714 |
+
|
715 |
+
def inferType(self, context, environment, freeVariables):
|
716 |
+
return self.tp.instantiate(context)
|
717 |
+
|
718 |
+
def shift(self, offset, depth=0): return self
|
719 |
+
|
720 |
+
def substitute(self, old, new):
|
721 |
+
if self == old:
|
722 |
+
return new
|
723 |
+
else:
|
724 |
+
return self
|
725 |
+
|
726 |
+
def walk(self, surroundingAbstractions=0): yield surroundingAbstractions, self
|
727 |
+
|
728 |
+
def walkUncurried(self, d=0): yield d, self
|
729 |
+
|
730 |
+
def size(self): return 1
|
731 |
+
|
732 |
+
@staticmethod
|
733 |
+
def _parse(s,n):
|
734 |
+
while n < len(s) and s[n].isspace(): n += 1
|
735 |
+
if n < len(s) and s[n] == '#':
|
736 |
+
n += 1
|
737 |
+
b,n = Program._parse(s,n)
|
738 |
+
return Invented(b),n
|
739 |
+
|
740 |
+
raise ParseFailure(s)
|
741 |
+
|
742 |
+
|
743 |
+
class FragmentVariable(Program):
|
744 |
+
def __init__(self): pass
|
745 |
+
|
746 |
+
def show(self, isFunction): return "??"
|
747 |
+
|
748 |
+
def __eq__(self, o): return isinstance(o, FragmentVariable)
|
749 |
+
|
750 |
+
def __hash__(self): return 42
|
751 |
+
|
752 |
+
def visit(self, visitor, *arguments, **keywords):
|
753 |
+
return visitor.fragmentVariable(self, *arguments, **keywords)
|
754 |
+
|
755 |
+
def evaluate(self, e):
|
756 |
+
raise Exception('Attempt to evaluate fragment variable')
|
757 |
+
|
758 |
+
def betaReduce(self):
|
759 |
+
raise Exception('Attempt to beta reduce fragment variable')
|
760 |
+
|
761 |
+
def inferType(self, context, environment, freeVariables):
|
762 |
+
return context.makeVariable()
|
763 |
+
|
764 |
+
def shift(self, offset, depth=0):
|
765 |
+
raise Exception('Attempt to shift fragment variable')
|
766 |
+
|
767 |
+
def substitute(self, old, new):
|
768 |
+
if self == old:
|
769 |
+
return new
|
770 |
+
else:
|
771 |
+
return self
|
772 |
+
|
773 |
+
def match(
|
774 |
+
self,
|
775 |
+
context,
|
776 |
+
expression,
|
777 |
+
holes,
|
778 |
+
variableBindings,
|
779 |
+
environment=[]):
|
780 |
+
surroundingAbstractions = len(environment)
|
781 |
+
try:
|
782 |
+
context, variable = context.makeVariable()
|
783 |
+
holes.append(
|
784 |
+
(variable, expression.shift(-surroundingAbstractions)))
|
785 |
+
return context, variable
|
786 |
+
except ShiftFailure:
|
787 |
+
raise MatchFailure()
|
788 |
+
|
789 |
+
def walk(self, surroundingAbstractions=0): yield surroundingAbstractions, self
|
790 |
+
|
791 |
+
def walkUncurried(self, d=0): yield d, self
|
792 |
+
|
793 |
+
def size(self): return 1
|
794 |
+
|
795 |
+
@staticmethod
|
796 |
+
def _parse(s,n):
|
797 |
+
while n < len(s) and s[n].isspace(): n += 1
|
798 |
+
n = Program.parseConstant(s,n,'??','?')
|
799 |
+
return FragmentVariable.single, n
|
800 |
+
|
801 |
+
FragmentVariable.single = FragmentVariable()
|
802 |
+
|
803 |
+
|
804 |
+
class Hole(Program):
|
805 |
+
def __init__(self): pass
|
806 |
+
|
807 |
+
def show(self, isFunction): return "<HOLE>"
|
808 |
+
|
809 |
+
@property
|
810 |
+
def isHole(self): return True
|
811 |
+
|
812 |
+
def __eq__(self, o): return isinstance(o, Hole)
|
813 |
+
|
814 |
+
def __hash__(self): return 42
|
815 |
+
|
816 |
+
def evaluate(self, e):
|
817 |
+
raise Exception('Attempt to evaluate hole')
|
818 |
+
|
819 |
+
def betaReduce(self):
|
820 |
+
raise Exception('Attempt to beta reduce hole')
|
821 |
+
|
822 |
+
def inferType(self, context, environment, freeVariables):
|
823 |
+
return context.makeVariable()
|
824 |
+
|
825 |
+
def shift(self, offset, depth=0):
|
826 |
+
raise Exception('Attempt to shift fragment variable')
|
827 |
+
|
828 |
+
def walk(self, surroundingAbstractions=0): yield surroundingAbstractions, self
|
829 |
+
|
830 |
+
def walkUncurried(self, d=0): yield d, self
|
831 |
+
|
832 |
+
def size(self): return 1
|
833 |
+
|
834 |
+
@staticmethod
|
835 |
+
def _parse(s,n):
|
836 |
+
while n < len(s) and s[n].isspace(): n += 1
|
837 |
+
n = Program.parseConstant(s,n,
|
838 |
+
'<HOLE>')
|
839 |
+
return Hole.single, n
|
840 |
+
|
841 |
+
|
842 |
+
Hole.single = Hole()
|
843 |
+
|
844 |
+
|
845 |
+
class ShareVisitor(object):
|
846 |
+
def __init__(self):
|
847 |
+
self.primitiveTable = {}
|
848 |
+
self.inventedTable = {}
|
849 |
+
self.indexTable = {}
|
850 |
+
self.applicationTable = {}
|
851 |
+
self.abstractionTable = {}
|
852 |
+
|
853 |
+
def invented(self, e):
|
854 |
+
body = e.body.visit(self)
|
855 |
+
i = id(body)
|
856 |
+
if i in self.inventedTable:
|
857 |
+
return self.inventedTable[i]
|
858 |
+
new = Invented(body)
|
859 |
+
self.inventedTable[i] = new
|
860 |
+
return new
|
861 |
+
|
862 |
+
def primitive(self, e):
|
863 |
+
if e.name in self.primitiveTable:
|
864 |
+
return self.primitiveTable[e.name]
|
865 |
+
self.primitiveTable[e.name] = e
|
866 |
+
return e
|
867 |
+
|
868 |
+
def index(self, e):
|
869 |
+
if e.i in self.indexTable:
|
870 |
+
return self.indexTable[e.i]
|
871 |
+
self.indexTable[e.i] = e
|
872 |
+
return e
|
873 |
+
|
874 |
+
def application(self, e):
|
875 |
+
f = e.f.visit(self)
|
876 |
+
x = e.x.visit(self)
|
877 |
+
fi = id(f)
|
878 |
+
xi = id(x)
|
879 |
+
i = (fi, xi)
|
880 |
+
if i in self.applicationTable:
|
881 |
+
return self.applicationTable[i]
|
882 |
+
new = Application(f, x)
|
883 |
+
self.applicationTable[i] = new
|
884 |
+
return new
|
885 |
+
|
886 |
+
def abstraction(self, e):
|
887 |
+
body = e.body.visit(self)
|
888 |
+
i = id(body)
|
889 |
+
if i in self.abstractionTable:
|
890 |
+
return self.abstractionTable[i]
|
891 |
+
new = Abstraction(body)
|
892 |
+
self.abstractionTable[i] = new
|
893 |
+
return new
|
894 |
+
|
895 |
+
def execute(self, e):
|
896 |
+
return e.visit(self)
|
897 |
+
|
898 |
+
|
899 |
+
class Mutator:
|
900 |
+
"""Perform local mutations to an expr, yielding the expr and the
|
901 |
+
description length distance from the original program"""
|
902 |
+
|
903 |
+
def __init__(self, grammar, fn):
|
904 |
+
"""Fn yields (expression, loglikelihood) from a type and loss.
|
905 |
+
Therefore, loss+loglikelihood is the distance from the original program."""
|
906 |
+
self.fn = fn
|
907 |
+
self.grammar = grammar
|
908 |
+
self.history = []
|
909 |
+
|
910 |
+
def enclose(self, expr):
|
911 |
+
for h in self.history[::-1]:
|
912 |
+
expr = h(expr)
|
913 |
+
return expr
|
914 |
+
|
915 |
+
def invented(self, e, tp, env, is_lhs=False):
|
916 |
+
deleted_ll = self.logLikelihood(tp, e, env)
|
917 |
+
for expr, replaced_ll in self.fn(tp, deleted, is_left_application=is_lhs):
|
918 |
+
yield self.enclose(expr), deleted_ll + replaced_ll
|
919 |
+
|
920 |
+
def primitive(self, e, tp, env, is_lhs=False):
|
921 |
+
deleted_ll = self.logLikelihood(tp, e, env)
|
922 |
+
for expr, replaced_ll in self.fn(tp, deleted_ll, is_left_application=is_lhs):
|
923 |
+
yield self.enclose(expr), deleted_ll + replaced_ll
|
924 |
+
|
925 |
+
def index(self, e, tp, env, is_lhs=False):
|
926 |
+
#yield from ()
|
927 |
+
deleted_ll = self.logLikelihood(tp, e, env) #self.grammar.logVariable
|
928 |
+
for expr, replaced_ll in self.fn(tp, deleted_ll, is_left_application=is_lhs):
|
929 |
+
yield self.enclose(expr), deleted_ll + replaced_ll
|
930 |
+
|
931 |
+
def application(self, e, tp, env, is_lhs=False):
|
932 |
+
self.history.append(lambda expr: Application(expr, e.x))
|
933 |
+
f_tp = arrow(e.x.infer(), tp)
|
934 |
+
yield from e.f.visit(self, f_tp, env, is_lhs=True)
|
935 |
+
self.history[-1] = lambda expr: Application(e.f, expr)
|
936 |
+
x_tp = inferArg(tp, e.f.infer())
|
937 |
+
yield from e.x.visit(self, x_tp, env)
|
938 |
+
self.history.pop()
|
939 |
+
deleted_ll = self.logLikelihood(tp, e, env)
|
940 |
+
for expr, replaced_ll in self.fn(tp, deleted_ll, is_left_application=is_lhs):
|
941 |
+
yield self.enclose(expr), deleted_ll + replaced_ll
|
942 |
+
|
943 |
+
def abstraction(self, e, tp, env, is_lhs=False):
|
944 |
+
self.history.append(lambda expr: Abstraction(expr))
|
945 |
+
yield from e.body.visit(self, tp.arguments[1], [tp.arguments[0]]+env)
|
946 |
+
self.history.pop()
|
947 |
+
deleted_ll = self.logLikelihood(tp, e, env)
|
948 |
+
for expr, replaced_ll in self.fn(tp, deleted_ll, is_left_application=is_lhs):
|
949 |
+
yield self.enclose(expr), deleted_ll + replaced_ll
|
950 |
+
|
951 |
+
def execute(self, e, tp):
|
952 |
+
yield from e.visit(self, tp, [])
|
953 |
+
|
954 |
+
def logLikelihood(self, tp, e, env):
|
955 |
+
summary = None
|
956 |
+
try:
|
957 |
+
_, summary = self.grammar.likelihoodSummary(Context.EMPTY, env,
|
958 |
+
tp, e, silent=True)
|
959 |
+
except AssertionError as err:
|
960 |
+
#print(f"closedLikelihoodSummary failed on tp={tp}, e={e}, error={err}")
|
961 |
+
pass
|
962 |
+
if summary is not None:
|
963 |
+
return summary.logLikelihood(self.grammar)
|
964 |
+
else:
|
965 |
+
tmpE, depth = e, 0
|
966 |
+
while isinstance(tmpE, Abstraction):
|
967 |
+
depth += 1
|
968 |
+
tmpE = tmpE.body
|
969 |
+
to_introduce = len(tp.functionArguments()) - depth
|
970 |
+
if to_introduce == 0:
|
971 |
+
#print(f"HIT NEGATIVEINFINITY, tp={tp}, e={e}")
|
972 |
+
return NEGATIVEINFINITY
|
973 |
+
for i in reversed(range(to_introduce)):
|
974 |
+
e = Application(e, Index(i))
|
975 |
+
for _ in range(to_introduce):
|
976 |
+
e = Abstraction(e)
|
977 |
+
return self.logLikelihood(tp, e, env)
|
978 |
+
|
979 |
+
|
980 |
+
class RegisterPrimitives(object):
|
981 |
+
def invented(self, e): e.body.visit(self)
|
982 |
+
|
983 |
+
def primitive(self, e):
|
984 |
+
if e.name not in Primitive.GLOBALS:
|
985 |
+
Primitive(e.name, e.tp, e.value)
|
986 |
+
|
987 |
+
def index(self, e): pass
|
988 |
+
|
989 |
+
def application(self, e):
|
990 |
+
e.f.visit(self)
|
991 |
+
e.x.visit(self)
|
992 |
+
|
993 |
+
def abstraction(self, e): e.body.visit(self)
|
994 |
+
|
995 |
+
@staticmethod
|
996 |
+
def register(e): e.visit(RegisterPrimitives())
|
997 |
+
|
998 |
+
|
999 |
+
class PrettyVisitor(object):
|
1000 |
+
def __init__(self, Lisp=False):
|
1001 |
+
self.Lisp = Lisp
|
1002 |
+
self.numberOfVariables = 0
|
1003 |
+
self.freeVariables = {}
|
1004 |
+
|
1005 |
+
self.variableNames = ["x", "y", "z", "u", "v", "w"]
|
1006 |
+
self.variableNames += [chr(ord('a') + j)
|
1007 |
+
for j in range(20)]
|
1008 |
+
self.toplevel = True
|
1009 |
+
|
1010 |
+
def makeVariable(self):
|
1011 |
+
v = self.variableNames[self.numberOfVariables]
|
1012 |
+
self.numberOfVariables += 1
|
1013 |
+
return v
|
1014 |
+
|
1015 |
+
def invented(self, e, environment, isFunction, isAbstraction):
|
1016 |
+
s = e.body.visit(self, [], isFunction, isAbstraction)
|
1017 |
+
return s
|
1018 |
+
|
1019 |
+
def primitive(self, e, environment, isVariable, isAbstraction): return e.name
|
1020 |
+
|
1021 |
+
def index(self, e, environment, isVariable, isAbstraction):
|
1022 |
+
if e.i < len(environment):
|
1023 |
+
return environment[e.i]
|
1024 |
+
else:
|
1025 |
+
i = e.i - len(environment)
|
1026 |
+
if i in self.freeVariables:
|
1027 |
+
return self.freeVariables[i]
|
1028 |
+
else:
|
1029 |
+
v = self.makeVariable()
|
1030 |
+
self.freeVariables[i] = v
|
1031 |
+
return v
|
1032 |
+
|
1033 |
+
def application(self, e, environment, isFunction, isAbstraction):
|
1034 |
+
self.toplevel = False
|
1035 |
+
s = "%s %s" % (e.f.visit(self, environment, True, False),
|
1036 |
+
e.x.visit(self, environment, False, False))
|
1037 |
+
if isFunction:
|
1038 |
+
return s
|
1039 |
+
else:
|
1040 |
+
return "(" + s + ")"
|
1041 |
+
|
1042 |
+
def abstraction(self, e, environment, isFunction, isAbstraction):
|
1043 |
+
toplevel = self.toplevel
|
1044 |
+
self.toplevel = False
|
1045 |
+
if not self.Lisp:
|
1046 |
+
# Invent a new variable
|
1047 |
+
v = self.makeVariable()
|
1048 |
+
body = e.body.visit(self,
|
1049 |
+
[v] + environment,
|
1050 |
+
False,
|
1051 |
+
True)
|
1052 |
+
if not e.body.isAbstraction:
|
1053 |
+
body = "." + body
|
1054 |
+
body = v + body
|
1055 |
+
if not isAbstraction:
|
1056 |
+
body = "λ" + body
|
1057 |
+
if not toplevel:
|
1058 |
+
body = "(%s)" % body
|
1059 |
+
return body
|
1060 |
+
else:
|
1061 |
+
child = e
|
1062 |
+
newVariables = []
|
1063 |
+
while child.isAbstraction:
|
1064 |
+
newVariables = [self.makeVariable()] + newVariables
|
1065 |
+
child = child.body
|
1066 |
+
body = child.visit(self, newVariables + environment,
|
1067 |
+
False, True)
|
1068 |
+
body = "(λ (%s) %s)"%(" ".join(reversed(newVariables)), body)
|
1069 |
+
return body
|
1070 |
+
|
1071 |
+
|
1072 |
+
|
1073 |
+
def prettyProgram(e, Lisp=False):
|
1074 |
+
return e.visit(PrettyVisitor(Lisp=Lisp), [], False, False)
|
1075 |
+
|
1076 |
+
class EtaExpandFailure(Exception): pass
|
1077 |
+
class EtaLongVisitor(object):
|
1078 |
+
"""Converts an expression into eta-longform"""
|
1079 |
+
def __init__(self, request=None):
|
1080 |
+
self.request = request
|
1081 |
+
self.context = None
|
1082 |
+
|
1083 |
+
def makeLong(self, e, request):
|
1084 |
+
if request.isArrow():
|
1085 |
+
# eta expansion
|
1086 |
+
return Abstraction(Application(e.shift(1),
|
1087 |
+
Index(0)))
|
1088 |
+
return None
|
1089 |
+
|
1090 |
+
|
1091 |
+
def abstraction(self, e, request, environment):
|
1092 |
+
if not request.isArrow(): raise EtaExpandFailure()
|
1093 |
+
|
1094 |
+
return Abstraction(e.body.visit(self,
|
1095 |
+
request.arguments[1],
|
1096 |
+
[request.arguments[0]] + environment))
|
1097 |
+
|
1098 |
+
def _application(self, e, request, environment):
|
1099 |
+
l = self.makeLong(e, request)
|
1100 |
+
if l is not None: return l.visit(self, request, environment)
|
1101 |
+
|
1102 |
+
f, xs = e.applicationParse()
|
1103 |
+
|
1104 |
+
if f.isIndex:
|
1105 |
+
ft = environment[f.i].applyMutable(self.context)
|
1106 |
+
elif f.isInvented or f.isPrimitive:
|
1107 |
+
ft = f.tp.instantiateMutable(self.context)
|
1108 |
+
else: assert False, "Not in beta long form: %s"%e
|
1109 |
+
|
1110 |
+
self.context.unify(request, ft.returns())
|
1111 |
+
ft = ft.applyMutable(self.context)
|
1112 |
+
|
1113 |
+
xt = ft.functionArguments()
|
1114 |
+
if len(xs) != len(xt): raise EtaExpandFailure()
|
1115 |
+
|
1116 |
+
returnValue = f
|
1117 |
+
for x,t in zip(xs,xt):
|
1118 |
+
t = t.applyMutable(self.context)
|
1119 |
+
returnValue = Application(returnValue,
|
1120 |
+
x.visit(self, t, environment))
|
1121 |
+
return returnValue
|
1122 |
+
|
1123 |
+
# This procedure works by recapitulating the generative process
|
1124 |
+
# applications indices and primitives are all generated identically
|
1125 |
+
|
1126 |
+
def application(self, e, request, environment): return self._application(e, request, environment)
|
1127 |
+
|
1128 |
+
def index(self, e, request, environment): return self._application(e, request, environment)
|
1129 |
+
|
1130 |
+
def primitive(self, e, request, environment): return self._application(e, request, environment)
|
1131 |
+
|
1132 |
+
def invented(self, e, request, environment): return self._application(e, request, environment)
|
1133 |
+
|
1134 |
+
def execute(self, e):
|
1135 |
+
assert len(e.freeVariables()) == 0
|
1136 |
+
|
1137 |
+
if self.request is None:
|
1138 |
+
eprint("WARNING: request not specified for etaexpansion")
|
1139 |
+
self.request = e.infer()
|
1140 |
+
self.context = MutableContext()
|
1141 |
+
el = e.visit(self, self.request, [])
|
1142 |
+
self.context = None
|
1143 |
+
# assert el.infer().canonical() == e.infer().canonical(), \
|
1144 |
+
# f"Types are not preserved by ETA expansion: {e} : {e.infer().canonical()} vs {el} : {el.infer().canonical()}"
|
1145 |
+
return el
|
1146 |
+
|
1147 |
+
|
1148 |
+
|
1149 |
+
class StripPrimitiveVisitor():
|
1150 |
+
"""Replaces all primitives .value's w/ None. Does not destructively modify anything"""
|
1151 |
+
def invented(self,e):
|
1152 |
+
return Invented(e.body.visit(self))
|
1153 |
+
def primitive(self,e):
|
1154 |
+
return Primitive(e.name,e.tp,None)
|
1155 |
+
def application(self,e):
|
1156 |
+
return Application(e.f.visit(self),
|
1157 |
+
e.x.visit(self))
|
1158 |
+
def abstraction(self,e):
|
1159 |
+
return Abstraction(e.body.visit(self))
|
1160 |
+
def index(self,e): return e
|
1161 |
+
|
1162 |
+
class ReplacePrimitiveValueVisitor():
|
1163 |
+
"""Intended to be used after StripPrimitiveVisitor.
|
1164 |
+
Replaces all primitive.value's with their corresponding entry in Primitive.GLOBALS"""
|
1165 |
+
def invented(self,e):
|
1166 |
+
return Invented(e.body.visit(self))
|
1167 |
+
def primitive(self,e):
|
1168 |
+
return Primitive(e.name,e.tp,Primitive.GLOBALS[e.name].value)
|
1169 |
+
def application(self,e):
|
1170 |
+
return Application(e.f.visit(self),
|
1171 |
+
e.x.visit(self))
|
1172 |
+
def abstraction(self,e):
|
1173 |
+
return Abstraction(e.body.visit(self))
|
1174 |
+
def index(self,e): return e
|
1175 |
+
|
1176 |
+
def strip_primitive_values(e):
|
1177 |
+
return e.visit(StripPrimitiveVisitor())
|
1178 |
+
def unstrip_primitive_values(e):
|
1179 |
+
return e.visit(ReplacePrimitiveValueVisitor())
|
1180 |
+
|
1181 |
+
|
1182 |
+
# from luke
|
1183 |
+
class TokeniseVisitor(object):
|
1184 |
+
def invented(self, e):
|
1185 |
+
return [e.body]
|
1186 |
+
|
1187 |
+
def primitive(self, e): return [e.name]
|
1188 |
+
|
1189 |
+
def index(self, e):
|
1190 |
+
return ["$" + str(e.i)]
|
1191 |
+
|
1192 |
+
def application(self, e):
|
1193 |
+
return ["("] + e.f.visit(self) + e.x.visit(self) + [")"]
|
1194 |
+
|
1195 |
+
def abstraction(self, e):
|
1196 |
+
return ["(_lambda"] + e.body.visit(self) + [")_lambda"]
|
1197 |
+
|
1198 |
+
|
1199 |
+
def tokeniseProgram(e):
|
1200 |
+
return e.visit(TokeniseVisitor())
|
1201 |
+
|
1202 |
+
|
1203 |
+
def untokeniseProgram(l):
|
1204 |
+
lookup = {
|
1205 |
+
"(_lambda": "(lambda",
|
1206 |
+
")_lambda": ")"
|
1207 |
+
}
|
1208 |
+
s = " ".join(lookup.get(x, x) for x in l)
|
1209 |
+
return Program.parse(s)
|
1210 |
+
|
1211 |
+
if __name__ == "__main__":
|
1212 |
+
from dreamcoder.domains.arithmetic.arithmeticPrimitives import *
|
1213 |
+
e = Program.parse("(#(lambda (?? (+ 1 $0))) (lambda (?? (+ 1 $0))) (lambda (?? (+ 1 $0))) - * (+ +))")
|
1214 |
+
eprint(e)
|
dreamcoder/recognition.py
ADDED
@@ -0,0 +1,1528 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dreamcoder.enumeration import *
|
2 |
+
from dreamcoder.grammar import *
|
3 |
+
# luke
|
4 |
+
|
5 |
+
|
6 |
+
import gc
|
7 |
+
|
8 |
+
try:
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from torch.autograd import Variable
|
13 |
+
from torch.nn.utils.rnn import pack_padded_sequence
|
14 |
+
except:
|
15 |
+
eprint("WARNING: Could not import torch. This is only okay when doing pypy compression.")
|
16 |
+
|
17 |
+
try:
|
18 |
+
import numpy as np
|
19 |
+
except:
|
20 |
+
eprint("WARNING: Could not import np. This is only okay when doing pypy compression.")
|
21 |
+
|
22 |
+
import json
|
23 |
+
|
24 |
+
|
25 |
+
def variable(x, volatile=False, cuda=False):
|
26 |
+
if isinstance(x, list):
|
27 |
+
x = np.array(x)
|
28 |
+
if isinstance(x, (np.ndarray, np.generic)):
|
29 |
+
x = torch.from_numpy(x)
|
30 |
+
if cuda:
|
31 |
+
x = x.cuda()
|
32 |
+
return Variable(x, volatile=volatile)
|
33 |
+
|
34 |
+
def maybe_cuda(x, use_cuda):
|
35 |
+
if use_cuda:
|
36 |
+
return x.cuda()
|
37 |
+
else:
|
38 |
+
return x
|
39 |
+
|
40 |
+
|
41 |
+
def is_torch_not_a_number(v):
|
42 |
+
"""checks whether a tortured variable is nan"""
|
43 |
+
v = v.data
|
44 |
+
if not ((v == v).item()):
|
45 |
+
return True
|
46 |
+
return False
|
47 |
+
|
48 |
+
def is_torch_invalid(v):
|
49 |
+
"""checks whether a torch variable is nan or inf"""
|
50 |
+
if is_torch_not_a_number(v):
|
51 |
+
return True
|
52 |
+
a = v - v
|
53 |
+
if is_torch_not_a_number(a):
|
54 |
+
return True
|
55 |
+
return False
|
56 |
+
|
57 |
+
def _relu(x): return x.clamp(min=0)
|
58 |
+
|
59 |
+
class Entropy(nn.Module):
|
60 |
+
"""Calculates the entropy of logits"""
|
61 |
+
def __init__(self):
|
62 |
+
super(Entropy, self).__init__()
|
63 |
+
|
64 |
+
def forward(self, x):
|
65 |
+
b = F.softmax(x, dim=0) * F.log_softmax(x, dim=0)
|
66 |
+
b = -1.0 * b.sum()
|
67 |
+
return b
|
68 |
+
|
69 |
+
class GrammarNetwork(nn.Module):
|
70 |
+
"""Neural network that outputs a grammar"""
|
71 |
+
def __init__(self, inputDimensionality, grammar):
|
72 |
+
super(GrammarNetwork, self).__init__()
|
73 |
+
self.logProductions = nn.Linear(inputDimensionality, len(grammar)+1)
|
74 |
+
self.grammar = grammar
|
75 |
+
|
76 |
+
def forward(self, x):
|
77 |
+
"""Takes as input inputDimensionality-dimensional vector and returns Grammar
|
78 |
+
Tensor-valued probabilities"""
|
79 |
+
logProductions = self.logProductions(x)
|
80 |
+
return Grammar(logProductions[-1].view(1), #logVariable
|
81 |
+
[(logProductions[k].view(1), t, program)
|
82 |
+
for k, (_, t, program) in enumerate(self.grammar.productions)],
|
83 |
+
continuationType=self.grammar.continuationType)
|
84 |
+
|
85 |
+
def batchedLogLikelihoods(self, xs, summaries):
|
86 |
+
"""Takes as input BxinputDimensionality vector & B likelihood summaries;
|
87 |
+
returns B-dimensional vector containing log likelihood of each summary"""
|
88 |
+
use_cuda = xs.device.type == 'cuda'
|
89 |
+
|
90 |
+
B = xs.size(0)
|
91 |
+
assert len(summaries) == B
|
92 |
+
logProductions = self.logProductions(xs)
|
93 |
+
|
94 |
+
# uses[b][p] is # uses of primitive p by summary b
|
95 |
+
uses = np.zeros((B,len(self.grammar) + 1))
|
96 |
+
for b,summary in enumerate(summaries):
|
97 |
+
for p, production in enumerate(self.grammar.primitives):
|
98 |
+
uses[b,p] = summary.uses.get(production, 0.)
|
99 |
+
uses[b,len(self.grammar)] = summary.uses.get(Index(0), 0)
|
100 |
+
|
101 |
+
numerator = (logProductions * maybe_cuda(torch.from_numpy(uses).float(),use_cuda)).sum(1)
|
102 |
+
numerator += maybe_cuda(torch.tensor([summary.constant for summary in summaries ]).float(), use_cuda)
|
103 |
+
|
104 |
+
alternativeSet = {normalizer
|
105 |
+
for s in summaries
|
106 |
+
for normalizer in s.normalizers }
|
107 |
+
alternativeSet = list(alternativeSet)
|
108 |
+
|
109 |
+
mask = np.zeros((len(alternativeSet), len(self.grammar) + 1))
|
110 |
+
for tau in range(len(alternativeSet)):
|
111 |
+
for p, production in enumerate(self.grammar.primitives):
|
112 |
+
mask[tau,p] = 0. if production in alternativeSet[tau] else NEGATIVEINFINITY
|
113 |
+
mask[tau,len(self.grammar)] = 0. if Index(0) in alternativeSet[tau] else NEGATIVEINFINITY
|
114 |
+
mask = maybe_cuda(torch.tensor(mask).float(), use_cuda)
|
115 |
+
|
116 |
+
# mask: Rx|G|
|
117 |
+
# logProductions: Bx|G|
|
118 |
+
# Want: mask + logProductions : BxRx|G| = z
|
119 |
+
z = mask.repeat(B,1,1) + logProductions.repeat(len(alternativeSet),1,1).transpose(1,0)
|
120 |
+
# z: BxR
|
121 |
+
z = torch.logsumexp(z, 2) # pytorch 1.0 dependency
|
122 |
+
|
123 |
+
# Calculate how many times each normalizer was used
|
124 |
+
N = np.zeros((B, len(alternativeSet)))
|
125 |
+
for b, summary in enumerate(summaries):
|
126 |
+
for tau, alternatives in enumerate(alternativeSet):
|
127 |
+
N[b, tau] = summary.normalizers.get(alternatives,0.)
|
128 |
+
|
129 |
+
denominator = (maybe_cuda(torch.tensor(N).float(),use_cuda) * z).sum(1)
|
130 |
+
return numerator - denominator
|
131 |
+
|
132 |
+
|
133 |
+
|
134 |
+
class ContextualGrammarNetwork_LowRank(nn.Module):
|
135 |
+
def __init__(self, inputDimensionality, grammar, R=16):
|
136 |
+
"""Low-rank approximation to bigram model. Parameters is linear in number of primitives.
|
137 |
+
R: maximum rank"""
|
138 |
+
|
139 |
+
super(ContextualGrammarNetwork_LowRank, self).__init__()
|
140 |
+
|
141 |
+
self.grammar = grammar
|
142 |
+
|
143 |
+
self.R = R # embedding size
|
144 |
+
|
145 |
+
# library now just contains a list of indicies which go with each primitive
|
146 |
+
self.grammar = grammar
|
147 |
+
self.library = {}
|
148 |
+
self.n_grammars = 0
|
149 |
+
for prim in grammar.primitives:
|
150 |
+
numberOfArguments = len(prim.infer().functionArguments())
|
151 |
+
idx_list = list(range(self.n_grammars, self.n_grammars+numberOfArguments))
|
152 |
+
self.library[prim] = idx_list
|
153 |
+
self.n_grammars += numberOfArguments
|
154 |
+
|
155 |
+
# We had an extra grammar for when there is no parent and for when the parent is a variable
|
156 |
+
self.n_grammars += 2
|
157 |
+
self.transitionMatrix = LowRank(inputDimensionality, self.n_grammars, len(grammar) + 1, R)
|
158 |
+
|
159 |
+
def grammarFromVector(self, logProductions):
|
160 |
+
return Grammar(logProductions[-1].view(1),
|
161 |
+
[(logProductions[k].view(1), t, program)
|
162 |
+
for k, (_, t, program) in enumerate(self.grammar.productions)],
|
163 |
+
continuationType=self.grammar.continuationType)
|
164 |
+
|
165 |
+
def forward(self, x):
|
166 |
+
assert len(x.size()) == 1, "contextual grammar doesn't currently support batching"
|
167 |
+
|
168 |
+
transitionMatrix = self.transitionMatrix(x)
|
169 |
+
|
170 |
+
return ContextualGrammar(self.grammarFromVector(transitionMatrix[-1]), self.grammarFromVector(transitionMatrix[-2]),
|
171 |
+
{prim: [self.grammarFromVector(transitionMatrix[j]) for j in js]
|
172 |
+
for prim, js in self.library.items()} )
|
173 |
+
|
174 |
+
def vectorizedLogLikelihoods(self, x, summaries):
|
175 |
+
B = len(summaries)
|
176 |
+
G = len(self.grammar) + 1
|
177 |
+
|
178 |
+
# Which column of the transition matrix corresponds to which primitive
|
179 |
+
primitiveColumn = {p: c
|
180 |
+
for c, (_1,_2,p) in enumerate(self.grammar.productions) }
|
181 |
+
primitiveColumn[Index(0)] = G - 1
|
182 |
+
# Which row of the transition matrix corresponds to which context
|
183 |
+
contextRow = {(parent, index): r
|
184 |
+
for parent, indices in self.library.items()
|
185 |
+
for index, r in enumerate(indices) }
|
186 |
+
contextRow[(None,None)] = self.n_grammars - 1
|
187 |
+
contextRow[(Index(0),None)] = self.n_grammars - 2
|
188 |
+
|
189 |
+
transitionMatrix = self.transitionMatrix(x)
|
190 |
+
|
191 |
+
# uses[b][g][p] is # uses of primitive p by summary b for parent g
|
192 |
+
uses = np.zeros((B,self.n_grammars,len(self.grammar)+1))
|
193 |
+
for b,summary in enumerate(summaries):
|
194 |
+
for e, ss in summary.library.items():
|
195 |
+
for g,s in zip(self.library[e], ss):
|
196 |
+
assert g < self.n_grammars - 2
|
197 |
+
for p, production in enumerate(self.grammar.primitives):
|
198 |
+
uses[b,g,p] = s.uses.get(production, 0.)
|
199 |
+
uses[b,g,len(self.grammar)] = s.uses.get(Index(0), 0)
|
200 |
+
|
201 |
+
# noParent: this is the last network output
|
202 |
+
for p, production in enumerate(self.grammar.primitives):
|
203 |
+
uses[b, self.n_grammars - 1, p] = summary.noParent.uses.get(production, 0.)
|
204 |
+
uses[b, self.n_grammars - 1, G - 1] = summary.noParent.uses.get(Index(0), 0.)
|
205 |
+
|
206 |
+
# variableParent: this is the penultimate network output
|
207 |
+
for p, production in enumerate(self.grammar.primitives):
|
208 |
+
uses[b, self.n_grammars - 2, p] = summary.variableParent.uses.get(production, 0.)
|
209 |
+
uses[b, self.n_grammars - 2, G - 1] = summary.variableParent.uses.get(Index(0), 0.)
|
210 |
+
|
211 |
+
uses = maybe_cuda(torch.tensor(uses).float(),use_cuda)
|
212 |
+
numerator = uses.view(B, -1) @ transitionMatrix.view(-1)
|
213 |
+
|
214 |
+
constant = np.zeros(B)
|
215 |
+
for b,summary in enumerate(summaries):
|
216 |
+
constant[b] += summary.noParent.constant + summary.variableParent.constant
|
217 |
+
for ss in summary.library.values():
|
218 |
+
for s in ss:
|
219 |
+
constant[b] += s.constant
|
220 |
+
|
221 |
+
numerator = numerator + maybe_cuda(torch.tensor(constant).float(),use_cuda)
|
222 |
+
|
223 |
+
# Calculate the god-awful denominator
|
224 |
+
# Map from (parent, index, {set-of-alternatives}) to [occurrences-in-summary-zero, occurrences-in-summary-one, ...]
|
225 |
+
alternativeSet = {}
|
226 |
+
for b,summary in enumerate(summaries):
|
227 |
+
for normalizer, frequency in summary.noParent.normalizers.items():
|
228 |
+
k = (None,None,normalizer)
|
229 |
+
alternativeSet[k] = alternativeSet.get(k, np.zeros(B))
|
230 |
+
alternativeSet[k][b] += frequency
|
231 |
+
for normalizer, frequency in summary.variableParent.normalizers.items():
|
232 |
+
k = (Index(0),None,normalizer)
|
233 |
+
alternativeSet[k] = alternativeSet.get(k, np.zeros(B))
|
234 |
+
alternativeSet[k][b] += frequency
|
235 |
+
for parent, ss in summary.library.items():
|
236 |
+
for argumentIndex, s in enumerate(ss):
|
237 |
+
for normalizer, frequency in s.normalizers.items():
|
238 |
+
k = (parent, argumentIndex, normalizer)
|
239 |
+
alternativeSet[k] = alternativeSet.get(k, zeros(B))
|
240 |
+
alternativeSet[k][b] += frequency
|
241 |
+
|
242 |
+
# Calculate each distinct normalizing constant
|
243 |
+
alternativeNormalizer = {}
|
244 |
+
for parent, index, alternatives in alternativeSet:
|
245 |
+
r = transitionMatrix[contextRow[(parent, index)]]
|
246 |
+
entries = r[ [primitiveColumn[alternative] for alternative in alternatives ]]
|
247 |
+
alternativeNormalizer[(parent, index, alternatives)] = torch.logsumexp(entries, dim=0)
|
248 |
+
|
249 |
+
# Concatenate the normalizers into a vector
|
250 |
+
normalizerKeys = list(alternativeSet.keys())
|
251 |
+
normalizerVector = torch.cat([ alternativeNormalizer[k] for k in normalizerKeys])
|
252 |
+
|
253 |
+
assert False, "This function is still in progress."
|
254 |
+
|
255 |
+
|
256 |
+
def batchedLogLikelihoods(self, xs, summaries):
|
257 |
+
"""Takes as input BxinputDimensionality vector & B likelihood summaries;
|
258 |
+
returns B-dimensional vector containing log likelihood of each summary"""
|
259 |
+
use_cuda = xs.device.type == 'cuda'
|
260 |
+
|
261 |
+
B = xs.shape[0]
|
262 |
+
G = len(self.grammar) + 1
|
263 |
+
assert len(summaries) == B
|
264 |
+
|
265 |
+
# logProductions: Bx n_grammars x G
|
266 |
+
logProductions = self.transitionMatrix(xs)
|
267 |
+
# uses[b][g][p] is # uses of primitive p by summary b for parent g
|
268 |
+
uses = np.zeros((B,self.n_grammars,len(self.grammar)+1))
|
269 |
+
for b,summary in enumerate(summaries):
|
270 |
+
for e, ss in summary.library.items():
|
271 |
+
for g,s in zip(self.library[e], ss):
|
272 |
+
assert g < self.n_grammars - 2
|
273 |
+
for p, production in enumerate(self.grammar.primitives):
|
274 |
+
uses[b,g,p] = s.uses.get(production, 0.)
|
275 |
+
uses[b,g,len(self.grammar)] = s.uses.get(Index(0), 0)
|
276 |
+
|
277 |
+
# noParent: this is the last network output
|
278 |
+
for p, production in enumerate(self.grammar.primitives):
|
279 |
+
uses[b, self.n_grammars - 1, p] = summary.noParent.uses.get(production, 0.)
|
280 |
+
uses[b, self.n_grammars - 1, G - 1] = summary.noParent.uses.get(Index(0), 0.)
|
281 |
+
|
282 |
+
# variableParent: this is the penultimate network output
|
283 |
+
for p, production in enumerate(self.grammar.primitives):
|
284 |
+
uses[b, self.n_grammars - 2, p] = summary.variableParent.uses.get(production, 0.)
|
285 |
+
uses[b, self.n_grammars - 2, G - 1] = summary.variableParent.uses.get(Index(0), 0.)
|
286 |
+
|
287 |
+
numerator = (logProductions*maybe_cuda(torch.tensor(uses).float(),use_cuda)).view(B,-1).sum(1)
|
288 |
+
|
289 |
+
constant = np.zeros(B)
|
290 |
+
for b,summary in enumerate(summaries):
|
291 |
+
constant[b] += summary.noParent.constant + summary.variableParent.constant
|
292 |
+
for ss in summary.library.values():
|
293 |
+
for s in ss:
|
294 |
+
constant[b] += s.constant
|
295 |
+
|
296 |
+
numerator += maybe_cuda(torch.tensor(constant).float(),use_cuda)
|
297 |
+
|
298 |
+
if True:
|
299 |
+
|
300 |
+
# Calculate the god-awful denominator
|
301 |
+
alternativeSet = set()
|
302 |
+
for summary in summaries:
|
303 |
+
for normalizer in summary.noParent.normalizers: alternativeSet.add(normalizer)
|
304 |
+
for normalizer in summary.variableParent.normalizers: alternativeSet.add(normalizer)
|
305 |
+
for ss in summary.library.values():
|
306 |
+
for s in ss:
|
307 |
+
for normalizer in s.normalizers: alternativeSet.add(normalizer)
|
308 |
+
alternativeSet = list(alternativeSet)
|
309 |
+
|
310 |
+
mask = np.zeros((len(alternativeSet), G))
|
311 |
+
for tau in range(len(alternativeSet)):
|
312 |
+
for p, production in enumerate(self.grammar.primitives):
|
313 |
+
mask[tau,p] = 0. if production in alternativeSet[tau] else NEGATIVEINFINITY
|
314 |
+
mask[tau, G - 1] = 0. if Index(0) in alternativeSet[tau] else NEGATIVEINFINITY
|
315 |
+
mask = maybe_cuda(torch.tensor(mask).float(), use_cuda)
|
316 |
+
|
317 |
+
z = mask.repeat(self.n_grammars,1,1).repeat(B,1,1,1) + \
|
318 |
+
logProductions.repeat(len(alternativeSet),1,1,1).transpose(0,1).transpose(1,2)
|
319 |
+
z = torch.logsumexp(z, 3) # pytorch 1.0 dependency
|
320 |
+
|
321 |
+
N = np.zeros((B, self.n_grammars, len(alternativeSet)))
|
322 |
+
for b, summary in enumerate(summaries):
|
323 |
+
for e, ss in summary.library.items():
|
324 |
+
for g,s in zip(self.library[e], ss):
|
325 |
+
assert g < self.n_grammars - 2
|
326 |
+
for r, alternatives in enumerate(alternativeSet):
|
327 |
+
N[b,g,r] = s.normalizers.get(alternatives, 0.)
|
328 |
+
# noParent: this is the last network output
|
329 |
+
for r, alternatives in enumerate(alternativeSet):
|
330 |
+
N[b,self.n_grammars - 1,r] = summary.noParent.normalizers.get(alternatives, 0.)
|
331 |
+
# variableParent: this is the penultimate network output
|
332 |
+
for r, alternatives in enumerate(alternativeSet):
|
333 |
+
N[b,self.n_grammars - 2,r] = summary.variableParent.normalizers.get(alternatives, 0.)
|
334 |
+
N = maybe_cuda(torch.tensor(N).float(), use_cuda)
|
335 |
+
denominator = (N*z).sum(1).sum(1)
|
336 |
+
else:
|
337 |
+
gs = [ self(xs[b]) for b in range(B) ]
|
338 |
+
denominator = torch.cat([ summary.denominator(g) for summary,g in zip(summaries, gs) ])
|
339 |
+
|
340 |
+
|
341 |
+
|
342 |
+
|
343 |
+
|
344 |
+
ll = numerator - denominator
|
345 |
+
|
346 |
+
if False: # verifying that batching works correctly
|
347 |
+
gs = [ self(xs[b]) for b in range(B) ]
|
348 |
+
_l = torch.cat([ summary.logLikelihood(g) for summary,g in zip(summaries, gs) ])
|
349 |
+
assert torch.all((ll - _l).abs() < 0.0001)
|
350 |
+
return ll
|
351 |
+
|
352 |
+
class ContextualGrammarNetwork_Mask(nn.Module):
|
353 |
+
def __init__(self, inputDimensionality, grammar):
|
354 |
+
"""Bigram model, but where the bigram transitions are unconditional.
|
355 |
+
Individual primitive probabilities are still conditional (predicted by neural network)
|
356 |
+
"""
|
357 |
+
|
358 |
+
super(ContextualGrammarNetwork_Mask, self).__init__()
|
359 |
+
|
360 |
+
self.grammar = grammar
|
361 |
+
|
362 |
+
# library now just contains a list of indicies which go with each primitive
|
363 |
+
self.grammar = grammar
|
364 |
+
self.library = {}
|
365 |
+
self.n_grammars = 0
|
366 |
+
for prim in grammar.primitives:
|
367 |
+
numberOfArguments = len(prim.infer().functionArguments())
|
368 |
+
idx_list = list(range(self.n_grammars, self.n_grammars+numberOfArguments))
|
369 |
+
self.library[prim] = idx_list
|
370 |
+
self.n_grammars += numberOfArguments
|
371 |
+
|
372 |
+
# We had an extra grammar for when there is no parent and for when the parent is a variable
|
373 |
+
self.n_grammars += 2
|
374 |
+
self._transitionMatrix = nn.Parameter(nn.init.xavier_uniform(torch.Tensor(self.n_grammars, len(grammar) + 1)))
|
375 |
+
self._logProductions = nn.Linear(inputDimensionality, len(grammar)+1)
|
376 |
+
|
377 |
+
def transitionMatrix(self, x):
|
378 |
+
if len(x.shape) == 1: # not batched
|
379 |
+
return self._logProductions(x) + self._transitionMatrix # will broadcast
|
380 |
+
elif len(x.shape) == 2: # batched
|
381 |
+
return self._logProductions(x).unsqueeze(1).repeat(1,self.n_grammars,1) + \
|
382 |
+
self._transitionMatrix.unsqueeze(0).repeat(x.size(0),1,1)
|
383 |
+
else:
|
384 |
+
assert False, "unknown shape for transition matrix input"
|
385 |
+
|
386 |
+
def grammarFromVector(self, logProductions):
|
387 |
+
return Grammar(logProductions[-1].view(1),
|
388 |
+
[(logProductions[k].view(1), t, program)
|
389 |
+
for k, (_, t, program) in enumerate(self.grammar.productions)],
|
390 |
+
continuationType=self.grammar.continuationType)
|
391 |
+
|
392 |
+
def forward(self, x):
|
393 |
+
assert len(x.size()) == 1, "contextual grammar doesn't currently support batching"
|
394 |
+
|
395 |
+
transitionMatrix = self.transitionMatrix(x)
|
396 |
+
|
397 |
+
return ContextualGrammar(self.grammarFromVector(transitionMatrix[-1]), self.grammarFromVector(transitionMatrix[-2]),
|
398 |
+
{prim: [self.grammarFromVector(transitionMatrix[j]) for j in js]
|
399 |
+
for prim, js in self.library.items()} )
|
400 |
+
|
401 |
+
def batchedLogLikelihoods(self, xs, summaries):
|
402 |
+
"""Takes as input BxinputDimensionality vector & B likelihood summaries;
|
403 |
+
returns B-dimensional vector containing log likelihood of each summary"""
|
404 |
+
use_cuda = xs.device.type == 'cuda'
|
405 |
+
|
406 |
+
B = xs.shape[0]
|
407 |
+
G = len(self.grammar) + 1
|
408 |
+
assert len(summaries) == B
|
409 |
+
|
410 |
+
# logProductions: Bx n_grammars x G
|
411 |
+
logProductions = self.transitionMatrix(xs)
|
412 |
+
# uses[b][g][p] is # uses of primitive p by summary b for parent g
|
413 |
+
uses = np.zeros((B,self.n_grammars,len(self.grammar)+1))
|
414 |
+
for b,summary in enumerate(summaries):
|
415 |
+
for e, ss in summary.library.items():
|
416 |
+
for g,s in zip(self.library[e], ss):
|
417 |
+
assert g < self.n_grammars - 2
|
418 |
+
for p, production in enumerate(self.grammar.primitives):
|
419 |
+
uses[b,g,p] = s.uses.get(production, 0.)
|
420 |
+
uses[b,g,len(self.grammar)] = s.uses.get(Index(0), 0)
|
421 |
+
|
422 |
+
# noParent: this is the last network output
|
423 |
+
for p, production in enumerate(self.grammar.primitives):
|
424 |
+
uses[b, self.n_grammars - 1, p] = summary.noParent.uses.get(production, 0.)
|
425 |
+
uses[b, self.n_grammars - 1, G - 1] = summary.noParent.uses.get(Index(0), 0.)
|
426 |
+
|
427 |
+
# variableParent: this is the penultimate network output
|
428 |
+
for p, production in enumerate(self.grammar.primitives):
|
429 |
+
uses[b, self.n_grammars - 2, p] = summary.variableParent.uses.get(production, 0.)
|
430 |
+
uses[b, self.n_grammars - 2, G - 1] = summary.variableParent.uses.get(Index(0), 0.)
|
431 |
+
|
432 |
+
numerator = (logProductions*maybe_cuda(torch.tensor(uses).float(),use_cuda)).view(B,-1).sum(1)
|
433 |
+
|
434 |
+
constant = np.zeros(B)
|
435 |
+
for b,summary in enumerate(summaries):
|
436 |
+
constant[b] += summary.noParent.constant + summary.variableParent.constant
|
437 |
+
for ss in summary.library.values():
|
438 |
+
for s in ss:
|
439 |
+
constant[b] += s.constant
|
440 |
+
|
441 |
+
numerator += maybe_cuda(torch.tensor(constant).float(),use_cuda)
|
442 |
+
|
443 |
+
if True:
|
444 |
+
|
445 |
+
# Calculate the god-awful denominator
|
446 |
+
alternativeSet = set()
|
447 |
+
for summary in summaries:
|
448 |
+
for normalizer in summary.noParent.normalizers: alternativeSet.add(normalizer)
|
449 |
+
for normalizer in summary.variableParent.normalizers: alternativeSet.add(normalizer)
|
450 |
+
for ss in summary.library.values():
|
451 |
+
for s in ss:
|
452 |
+
for normalizer in s.normalizers: alternativeSet.add(normalizer)
|
453 |
+
alternativeSet = list(alternativeSet)
|
454 |
+
|
455 |
+
mask = np.zeros((len(alternativeSet), G))
|
456 |
+
for tau in range(len(alternativeSet)):
|
457 |
+
for p, production in enumerate(self.grammar.primitives):
|
458 |
+
mask[tau,p] = 0. if production in alternativeSet[tau] else NEGATIVEINFINITY
|
459 |
+
mask[tau, G - 1] = 0. if Index(0) in alternativeSet[tau] else NEGATIVEINFINITY
|
460 |
+
mask = maybe_cuda(torch.tensor(mask).float(), use_cuda)
|
461 |
+
|
462 |
+
z = mask.repeat(self.n_grammars,1,1).repeat(B,1,1,1) + \
|
463 |
+
logProductions.repeat(len(alternativeSet),1,1,1).transpose(0,1).transpose(1,2)
|
464 |
+
z = torch.logsumexp(z, 3) # pytorch 1.0 dependency
|
465 |
+
|
466 |
+
N = np.zeros((B, self.n_grammars, len(alternativeSet)))
|
467 |
+
for b, summary in enumerate(summaries):
|
468 |
+
for e, ss in summary.library.items():
|
469 |
+
for g,s in zip(self.library[e], ss):
|
470 |
+
assert g < self.n_grammars - 2
|
471 |
+
for r, alternatives in enumerate(alternativeSet):
|
472 |
+
N[b,g,r] = s.normalizers.get(alternatives, 0.)
|
473 |
+
# noParent: this is the last network output
|
474 |
+
for r, alternatives in enumerate(alternativeSet):
|
475 |
+
N[b,self.n_grammars - 1,r] = summary.noParent.normalizers.get(alternatives, 0.)
|
476 |
+
# variableParent: this is the penultimate network output
|
477 |
+
for r, alternatives in enumerate(alternativeSet):
|
478 |
+
N[b,self.n_grammars - 2,r] = summary.variableParent.normalizers.get(alternatives, 0.)
|
479 |
+
N = maybe_cuda(torch.tensor(N).float(), use_cuda)
|
480 |
+
denominator = (N*z).sum(1).sum(1)
|
481 |
+
else:
|
482 |
+
gs = [ self(xs[b]) for b in range(B) ]
|
483 |
+
denominator = torch.cat([ summary.denominator(g) for summary,g in zip(summaries, gs) ])
|
484 |
+
|
485 |
+
|
486 |
+
|
487 |
+
|
488 |
+
|
489 |
+
ll = numerator - denominator
|
490 |
+
|
491 |
+
if False: # verifying that batching works correctly
|
492 |
+
gs = [ self(xs[b]) for b in range(B) ]
|
493 |
+
_l = torch.cat([ summary.logLikelihood(g) for summary,g in zip(summaries, gs) ])
|
494 |
+
assert torch.all((ll - _l).abs() < 0.0001)
|
495 |
+
return ll
|
496 |
+
|
497 |
+
|
498 |
+
|
499 |
+
class ContextualGrammarNetwork(nn.Module):
|
500 |
+
"""Like GrammarNetwork but ~contextual~"""
|
501 |
+
def __init__(self, inputDimensionality, grammar):
|
502 |
+
super(ContextualGrammarNetwork, self).__init__()
|
503 |
+
|
504 |
+
# library now just contains a list of indicies which go with each primitive
|
505 |
+
self.grammar = grammar
|
506 |
+
self.library = {}
|
507 |
+
self.n_grammars = 0
|
508 |
+
for prim in grammar.primitives:
|
509 |
+
numberOfArguments = len(prim.infer().functionArguments())
|
510 |
+
idx_list = list(range(self.n_grammars, self.n_grammars+numberOfArguments))
|
511 |
+
self.library[prim] = idx_list
|
512 |
+
self.n_grammars += numberOfArguments
|
513 |
+
|
514 |
+
# We had an extra grammar for when there is no parent and for when the parent is a variable
|
515 |
+
self.n_grammars += 2
|
516 |
+
self.network = nn.Linear(inputDimensionality, (self.n_grammars)*(len(grammar) + 1))
|
517 |
+
|
518 |
+
|
519 |
+
def grammarFromVector(self, logProductions):
|
520 |
+
return Grammar(logProductions[-1].view(1),
|
521 |
+
[(logProductions[k].view(1), t, program)
|
522 |
+
for k, (_, t, program) in enumerate(self.grammar.productions)],
|
523 |
+
continuationType=self.grammar.continuationType)
|
524 |
+
|
525 |
+
def forward(self, x):
|
526 |
+
assert len(x.size()) == 1, "contextual grammar doesn't currently support batching"
|
527 |
+
|
528 |
+
allVars = self.network(x).view(self.n_grammars, -1)
|
529 |
+
return ContextualGrammar(self.grammarFromVector(allVars[-1]), self.grammarFromVector(allVars[-2]),
|
530 |
+
{prim: [self.grammarFromVector(allVars[j]) for j in js]
|
531 |
+
for prim, js in self.library.items()} )
|
532 |
+
|
533 |
+
def batchedLogLikelihoods(self, xs, summaries):
|
534 |
+
use_cuda = xs.device.type == 'cuda'
|
535 |
+
"""Takes as input BxinputDimensionality vector & B likelihood summaries;
|
536 |
+
returns B-dimensional vector containing log likelihood of each summary"""
|
537 |
+
|
538 |
+
B = xs.shape[0]
|
539 |
+
G = len(self.grammar) + 1
|
540 |
+
assert len(summaries) == B
|
541 |
+
|
542 |
+
# logProductions: Bx n_grammars x G
|
543 |
+
logProductions = self.network(xs).view(B, self.n_grammars, G)
|
544 |
+
# uses[b][g][p] is # uses of primitive p by summary b for parent g
|
545 |
+
uses = np.zeros((B,self.n_grammars,len(self.grammar)+1))
|
546 |
+
for b,summary in enumerate(summaries):
|
547 |
+
for e, ss in summary.library.items():
|
548 |
+
for g,s in zip(self.library[e], ss):
|
549 |
+
assert g < self.n_grammars - 2
|
550 |
+
for p, production in enumerate(self.grammar.primitives):
|
551 |
+
uses[b,g,p] = s.uses.get(production, 0.)
|
552 |
+
uses[b,g,len(self.grammar)] = s.uses.get(Index(0), 0)
|
553 |
+
|
554 |
+
# noParent: this is the last network output
|
555 |
+
for p, production in enumerate(self.grammar.primitives):
|
556 |
+
uses[b, self.n_grammars - 1, p] = summary.noParent.uses.get(production, 0.)
|
557 |
+
uses[b, self.n_grammars - 1, G - 1] = summary.noParent.uses.get(Index(0), 0.)
|
558 |
+
|
559 |
+
# variableParent: this is the penultimate network output
|
560 |
+
for p, production in enumerate(self.grammar.primitives):
|
561 |
+
uses[b, self.n_grammars - 2, p] = summary.variableParent.uses.get(production, 0.)
|
562 |
+
uses[b, self.n_grammars - 2, G - 1] = summary.variableParent.uses.get(Index(0), 0.)
|
563 |
+
|
564 |
+
numerator = (logProductions*maybe_cuda(torch.tensor(uses).float(),use_cuda)).view(B,-1).sum(1)
|
565 |
+
|
566 |
+
constant = np.zeros(B)
|
567 |
+
for b,summary in enumerate(summaries):
|
568 |
+
constant[b] += summary.noParent.constant + summary.variableParent.constant
|
569 |
+
for ss in summary.library.values():
|
570 |
+
for s in ss:
|
571 |
+
constant[b] += s.constant
|
572 |
+
|
573 |
+
numerator += maybe_cuda(torch.tensor(constant).float(),use_cuda)
|
574 |
+
|
575 |
+
# Calculate the god-awful denominator
|
576 |
+
alternativeSet = set()
|
577 |
+
for summary in summaries:
|
578 |
+
for normalizer in summary.noParent.normalizers: alternativeSet.add(normalizer)
|
579 |
+
for normalizer in summary.variableParent.normalizers: alternativeSet.add(normalizer)
|
580 |
+
for ss in summary.library.values():
|
581 |
+
for s in ss:
|
582 |
+
for normalizer in s.normalizers: alternativeSet.add(normalizer)
|
583 |
+
alternativeSet = list(alternativeSet)
|
584 |
+
|
585 |
+
mask = np.zeros((len(alternativeSet), G))
|
586 |
+
for tau in range(len(alternativeSet)):
|
587 |
+
for p, production in enumerate(self.grammar.primitives):
|
588 |
+
mask[tau,p] = 0. if production in alternativeSet[tau] else NEGATIVEINFINITY
|
589 |
+
mask[tau, G - 1] = 0. if Index(0) in alternativeSet[tau] else NEGATIVEINFINITY
|
590 |
+
mask = maybe_cuda(torch.tensor(mask).float(), use_cuda)
|
591 |
+
|
592 |
+
z = mask.repeat(self.n_grammars,1,1).repeat(B,1,1,1) + \
|
593 |
+
logProductions.repeat(len(alternativeSet),1,1,1).transpose(0,1).transpose(1,2)
|
594 |
+
z = torch.logsumexp(z, 3) # pytorch 1.0 dependency
|
595 |
+
|
596 |
+
N = np.zeros((B, self.n_grammars, len(alternativeSet)))
|
597 |
+
for b, summary in enumerate(summaries):
|
598 |
+
for e, ss in summary.library.items():
|
599 |
+
for g,s in zip(self.library[e], ss):
|
600 |
+
assert g < self.n_grammars - 2
|
601 |
+
for r, alternatives in enumerate(alternativeSet):
|
602 |
+
N[b,g,r] = s.normalizers.get(alternatives, 0.)
|
603 |
+
# noParent: this is the last network output
|
604 |
+
for r, alternatives in enumerate(alternativeSet):
|
605 |
+
N[b,self.n_grammars - 1,r] = summary.noParent.normalizers.get(alternatives, 0.)
|
606 |
+
# variableParent: this is the penultimate network output
|
607 |
+
for r, alternatives in enumerate(alternativeSet):
|
608 |
+
N[b,self.n_grammars - 2,r] = summary.variableParent.normalizers.get(alternatives, 0.)
|
609 |
+
N = maybe_cuda(torch.tensor(N).float(), use_cuda)
|
610 |
+
|
611 |
+
|
612 |
+
|
613 |
+
denominator = (N*z).sum(1).sum(1)
|
614 |
+
ll = numerator - denominator
|
615 |
+
|
616 |
+
if False: # verifying that batching works correctly
|
617 |
+
gs = [ self(xs[b]) for b in range(B) ]
|
618 |
+
_l = torch.cat([ summary.logLikelihood(g) for summary,g in zip(summaries, gs) ])
|
619 |
+
assert torch.all((ll - _l).abs() < 0.0001)
|
620 |
+
|
621 |
+
return ll
|
622 |
+
|
623 |
+
|
624 |
+
class RecognitionModel(nn.Module):
|
625 |
+
def __init__(self,featureExtractor,grammar,hidden=[64],activation="tanh",
|
626 |
+
rank=None,contextual=False,mask=False,
|
627 |
+
cuda=False,
|
628 |
+
previousRecognitionModel=None,
|
629 |
+
id=0):
|
630 |
+
super(RecognitionModel, self).__init__()
|
631 |
+
self.id = id
|
632 |
+
self.trained=False
|
633 |
+
self.use_cuda = cuda
|
634 |
+
|
635 |
+
self.featureExtractor = featureExtractor
|
636 |
+
# Sanity check - make sure that all of the parameters of the
|
637 |
+
# feature extractor were added to our parameters as well
|
638 |
+
if hasattr(featureExtractor, 'parameters'):
|
639 |
+
for parameter in featureExtractor.parameters():
|
640 |
+
assert any(myParameter is parameter for myParameter in self.parameters())
|
641 |
+
|
642 |
+
# Build the multilayer perceptron that is sandwiched between the feature extractor and the grammar
|
643 |
+
if activation == "sigmoid":
|
644 |
+
activation = nn.Sigmoid
|
645 |
+
elif activation == "relu":
|
646 |
+
activation = nn.ReLU
|
647 |
+
elif activation == "tanh":
|
648 |
+
activation = nn.Tanh
|
649 |
+
else:
|
650 |
+
raise Exception('Unknown activation function ' + str(activation))
|
651 |
+
self._MLP = nn.Sequential(*[ layer
|
652 |
+
for j in range(len(hidden))
|
653 |
+
for layer in [
|
654 |
+
nn.Linear(([featureExtractor.outputDimensionality] + hidden)[j],
|
655 |
+
hidden[j]),
|
656 |
+
activation()]])
|
657 |
+
|
658 |
+
self.entropy = Entropy()
|
659 |
+
|
660 |
+
if len(hidden) > 0:
|
661 |
+
self.outputDimensionality = self._MLP[-2].out_features
|
662 |
+
assert self.outputDimensionality == hidden[-1]
|
663 |
+
else:
|
664 |
+
self.outputDimensionality = self.featureExtractor.outputDimensionality
|
665 |
+
|
666 |
+
self.contextual = contextual
|
667 |
+
if self.contextual:
|
668 |
+
if mask:
|
669 |
+
self.grammarBuilder = ContextualGrammarNetwork_Mask(self.outputDimensionality, grammar)
|
670 |
+
else:
|
671 |
+
self.grammarBuilder = ContextualGrammarNetwork_LowRank(self.outputDimensionality, grammar, rank)
|
672 |
+
else:
|
673 |
+
self.grammarBuilder = GrammarNetwork(self.outputDimensionality, grammar)
|
674 |
+
|
675 |
+
self.grammar = ContextualGrammar.fromGrammar(grammar) if contextual else grammar
|
676 |
+
self.generativeModel = grammar
|
677 |
+
|
678 |
+
self._auxiliaryPrediction = nn.Linear(self.featureExtractor.outputDimensionality,
|
679 |
+
len(self.grammar.primitives))
|
680 |
+
self._auxiliaryLoss = nn.BCEWithLogitsLoss()
|
681 |
+
|
682 |
+
if cuda: self.cuda()
|
683 |
+
|
684 |
+
if previousRecognitionModel:
|
685 |
+
self._MLP.load_state_dict(previousRecognitionModel._MLP.state_dict())
|
686 |
+
self.featureExtractor.load_state_dict(previousRecognitionModel.featureExtractor.state_dict())
|
687 |
+
|
688 |
+
def auxiliaryLoss(self, frontier, features):
|
689 |
+
# Compute a vector of uses
|
690 |
+
ls = frontier.bestPosterior.program
|
691 |
+
def uses(summary):
|
692 |
+
if hasattr(summary, 'uses'):
|
693 |
+
return torch.tensor([ float(int(p in summary.uses))
|
694 |
+
for p in self.generativeModel.primitives ])
|
695 |
+
assert hasattr(summary, 'noParent')
|
696 |
+
u = uses(summary.noParent) + uses(summary.variableParent)
|
697 |
+
for ss in summary.library.values():
|
698 |
+
for s in ss:
|
699 |
+
u += uses(s)
|
700 |
+
return u
|
701 |
+
u = uses(ls)
|
702 |
+
u[u > 1.] = 1.
|
703 |
+
if self.use_cuda: u = u.cuda()
|
704 |
+
al = self._auxiliaryLoss(self._auxiliaryPrediction(features), u)
|
705 |
+
return al
|
706 |
+
|
707 |
+
def taskEmbeddings(self, tasks):
|
708 |
+
return {task: self.featureExtractor.featuresOfTask(task).data.cpu().numpy()
|
709 |
+
for task in tasks}
|
710 |
+
|
711 |
+
def forward(self, features):
|
712 |
+
"""returns either a Grammar or a ContextualGrammar
|
713 |
+
Takes as input the output of featureExtractor.featuresOfTask"""
|
714 |
+
features = self._MLP(features)
|
715 |
+
return self.grammarBuilder(features)
|
716 |
+
|
717 |
+
def auxiliaryPrimitiveEmbeddings(self):
|
718 |
+
"""Returns the actual outputDimensionality weight vectors for each of the primitives."""
|
719 |
+
auxiliaryWeights = self._auxiliaryPrediction.weight.data.cpu().numpy()
|
720 |
+
primitivesDict = {self.grammar.primitives[i] : auxiliaryWeights[i, :] for i in range(len(self.grammar.primitives))}
|
721 |
+
return primitivesDict
|
722 |
+
|
723 |
+
def grammarOfTask(self, task):
|
724 |
+
features = self.featureExtractor.featuresOfTask(task)
|
725 |
+
if features is None: return None
|
726 |
+
return self(features)
|
727 |
+
|
728 |
+
def grammarLogProductionsOfTask(self, task):
|
729 |
+
"""Returns the grammar logits from non-contextual models."""
|
730 |
+
|
731 |
+
features = self.featureExtractor.featuresOfTask(task)
|
732 |
+
if features is None: return None
|
733 |
+
|
734 |
+
if hasattr(self, 'hiddenLayers'):
|
735 |
+
# Backward compatability with old checkpoints.
|
736 |
+
for layer in self.hiddenLayers:
|
737 |
+
features = self.activation(layer(features))
|
738 |
+
# return features
|
739 |
+
return self.noParent[1](features)
|
740 |
+
else:
|
741 |
+
features = self._MLP(features)
|
742 |
+
|
743 |
+
if self.contextual:
|
744 |
+
if hasattr(self.grammarBuilder, 'variableParent'):
|
745 |
+
return self.grammarBuilder.variableParent.logProductions(features)
|
746 |
+
elif hasattr(self.grammarBuilder, 'network'):
|
747 |
+
return self.grammarBuilder.network(features).view(-1)
|
748 |
+
elif hasattr(self.grammarBuilder, 'transitionMatrix'):
|
749 |
+
return self.grammarBuilder.transitionMatrix(features).view(-1)
|
750 |
+
else:
|
751 |
+
assert False
|
752 |
+
else:
|
753 |
+
return self.grammarBuilder.logProductions(features)
|
754 |
+
|
755 |
+
def grammarFeatureLogProductionsOfTask(self, task):
|
756 |
+
return torch.tensor(self.grammarOfTask(task).untorch().featureVector())
|
757 |
+
|
758 |
+
def grammarLogProductionDistanceToTask(self, task, tasks):
|
759 |
+
"""Returns the cosine similarity of all other tasks to a given task."""
|
760 |
+
taskLogits = self.grammarLogProductionsOfTask(task).unsqueeze(0) # Change to [1, D]
|
761 |
+
assert taskLogits is not None, 'Grammar log productions are not defined for this task.'
|
762 |
+
otherTasks = [t for t in tasks if t is not task] # [nTasks -1 , D]
|
763 |
+
|
764 |
+
# Build matrix of all other tasks.
|
765 |
+
otherLogits = torch.stack([self.grammarLogProductionsOfTask(t) for t in otherTasks])
|
766 |
+
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
|
767 |
+
cosMatrix = cos(taskLogits, otherLogits)
|
768 |
+
return cosMatrix.data.cpu().numpy()
|
769 |
+
|
770 |
+
def grammarEntropyOfTask(self, task):
|
771 |
+
"""Returns the entropy of the grammar distribution from non-contextual models for a task."""
|
772 |
+
grammarLogProductionsOfTask = self.grammarLogProductionsOfTask(task)
|
773 |
+
|
774 |
+
if grammarLogProductionsOfTask is None: return None
|
775 |
+
|
776 |
+
if hasattr(self, 'entropy'):
|
777 |
+
return self.entropy(grammarLogProductionsOfTask)
|
778 |
+
else:
|
779 |
+
e = Entropy()
|
780 |
+
return e(grammarLogProductionsOfTask)
|
781 |
+
|
782 |
+
def taskAuxiliaryLossLayer(self, tasks):
|
783 |
+
return {task: self._auxiliaryPrediction(self.featureExtractor.featuresOfTask(task)).view(-1).data.cpu().numpy()
|
784 |
+
for task in tasks}
|
785 |
+
|
786 |
+
def taskGrammarFeatureLogProductions(self, tasks):
|
787 |
+
return {task: self.grammarFeatureLogProductionsOfTask(task).data.cpu().numpy()
|
788 |
+
for task in tasks}
|
789 |
+
|
790 |
+
def taskGrammarLogProductions(self, tasks):
|
791 |
+
return {task: self.grammarLogProductionsOfTask(task).data.cpu().numpy()
|
792 |
+
for task in tasks}
|
793 |
+
|
794 |
+
def taskGrammarStartProductions(self, tasks):
|
795 |
+
return {task: np.array([l for l,_1,_2 in g.productions ])
|
796 |
+
for task in tasks
|
797 |
+
for g in [self.grammarOfTask(task).untorch().noParent] }
|
798 |
+
|
799 |
+
def taskHiddenStates(self, tasks):
|
800 |
+
return {task: self._MLP(self.featureExtractor.featuresOfTask(task)).view(-1).data.cpu().numpy()
|
801 |
+
for task in tasks}
|
802 |
+
|
803 |
+
def taskGrammarEntropies(self, tasks):
|
804 |
+
return {task: self.grammarEntropyOfTask(task).data.cpu().numpy()
|
805 |
+
for task in tasks}
|
806 |
+
|
807 |
+
def frontierKL(self, frontier, auxiliary=False, vectorized=True):
|
808 |
+
features = self.featureExtractor.featuresOfTask(frontier.task)
|
809 |
+
if features is None:
|
810 |
+
return None, None
|
811 |
+
# Monte Carlo estimate: draw a sample from the frontier
|
812 |
+
entry = frontier.sample()
|
813 |
+
|
814 |
+
al = self.auxiliaryLoss(frontier, features if auxiliary else features.detach())
|
815 |
+
|
816 |
+
if not vectorized:
|
817 |
+
g = self(features)
|
818 |
+
return - entry.program.logLikelihood(g), al
|
819 |
+
else:
|
820 |
+
features = self._MLP(features).unsqueeze(0)
|
821 |
+
|
822 |
+
ll = self.grammarBuilder.batchedLogLikelihoods(features, [entry.program]).view(-1)
|
823 |
+
return -ll, al
|
824 |
+
|
825 |
+
|
826 |
+
def frontierBiasOptimal(self, frontier, auxiliary=False, vectorized=True):
|
827 |
+
if not vectorized:
|
828 |
+
features = self.featureExtractor.featuresOfTask(frontier.task)
|
829 |
+
if features is None: return None, None
|
830 |
+
al = self.auxiliaryLoss(frontier, features if auxiliary else features.detach())
|
831 |
+
g = self(features)
|
832 |
+
summaries = [entry.program for entry in frontier]
|
833 |
+
likelihoods = torch.cat([entry.program.logLikelihood(g) + entry.logLikelihood
|
834 |
+
for entry in frontier ])
|
835 |
+
best = likelihoods.max()
|
836 |
+
return -best, al
|
837 |
+
|
838 |
+
batchSize = len(frontier.entries)
|
839 |
+
features = self.featureExtractor.featuresOfTask(frontier.task)
|
840 |
+
if features is None: return None, None
|
841 |
+
al = self.auxiliaryLoss(frontier, features if auxiliary else features.detach())
|
842 |
+
features = self._MLP(features)
|
843 |
+
features = features.expand(batchSize, features.size(-1)) # TODO
|
844 |
+
lls = self.grammarBuilder.batchedLogLikelihoods(features, [entry.program for entry in frontier])
|
845 |
+
actual_ll = torch.Tensor([ entry.logLikelihood for entry in frontier])
|
846 |
+
lls = lls + (actual_ll.cuda() if self.use_cuda else actual_ll)
|
847 |
+
ml = -lls.max() #Beware that inputs to max change output type
|
848 |
+
return ml, al
|
849 |
+
|
850 |
+
def replaceProgramsWithLikelihoodSummaries(self, frontier):
|
851 |
+
return Frontier(
|
852 |
+
[FrontierEntry(
|
853 |
+
program=self.grammar.closedLikelihoodSummary(frontier.task.request, e.program),
|
854 |
+
logLikelihood=e.logLikelihood,
|
855 |
+
logPrior=e.logPrior) for e in frontier],
|
856 |
+
task=frontier.task)
|
857 |
+
|
858 |
+
def train(self, frontiers, _=None, steps=None, lr=0.001, topK=5, CPUs=1,
|
859 |
+
timeout=None, evaluationTimeout=0.001,
|
860 |
+
helmholtzFrontiers=[], helmholtzRatio=0., helmholtzBatch=500,
|
861 |
+
biasOptimal=None, defaultRequest=None, auxLoss=False, vectorized=True):
|
862 |
+
"""
|
863 |
+
helmholtzRatio: What fraction of the training data should be forward samples from the generative model?
|
864 |
+
helmholtzFrontiers: Frontiers from programs enumerated from generative model (optional)
|
865 |
+
If helmholtzFrontiers is not provided then we will sample programs during training
|
866 |
+
"""
|
867 |
+
assert (steps is not None) or (timeout is not None), \
|
868 |
+
"Cannot train recognition model without either a bound on the number of gradient steps or bound on the training time"
|
869 |
+
if steps is None: steps = 9999999
|
870 |
+
if biasOptimal is None: biasOptimal = len(helmholtzFrontiers) > 0
|
871 |
+
|
872 |
+
requests = [frontier.task.request for frontier in frontiers]
|
873 |
+
if len(requests) == 0 and helmholtzRatio > 0 and len(helmholtzFrontiers) == 0:
|
874 |
+
assert defaultRequest is not None, "You are trying to random Helmholtz training, but don't have any frontiers. Therefore we would not know the type of the program to sample. Try specifying defaultRequest=..."
|
875 |
+
requests = [defaultRequest]
|
876 |
+
frontiers = [frontier.topK(topK).normalize()
|
877 |
+
for frontier in frontiers if not frontier.empty]
|
878 |
+
if len(frontiers) == 0:
|
879 |
+
eprint("You didn't give me any nonempty replay frontiers to learn from. Going to learn from 100% Helmholtz samples")
|
880 |
+
helmholtzRatio = 1.
|
881 |
+
|
882 |
+
# Should we sample programs or use the enumerated programs?
|
883 |
+
randomHelmholtz = len(helmholtzFrontiers) == 0
|
884 |
+
|
885 |
+
class HelmholtzEntry:
|
886 |
+
def __init__(self, frontier, owner):
|
887 |
+
self.request = frontier.task.request
|
888 |
+
self.task = None
|
889 |
+
self.programs = [e.program for e in frontier]
|
890 |
+
self.frontier = Thunk(lambda: owner.replaceProgramsWithLikelihoodSummaries(frontier))
|
891 |
+
self.owner = owner
|
892 |
+
|
893 |
+
def clear(self): self.task = None
|
894 |
+
|
895 |
+
def calculateTask(self):
|
896 |
+
assert self.task is None
|
897 |
+
p = random.choice(self.programs)
|
898 |
+
return self.owner.featureExtractor.taskOfProgram(p, self.request)
|
899 |
+
|
900 |
+
def makeFrontier(self):
|
901 |
+
assert self.task is not None
|
902 |
+
f = Frontier(self.frontier.force().entries,
|
903 |
+
task=self.task)
|
904 |
+
return f
|
905 |
+
|
906 |
+
|
907 |
+
|
908 |
+
|
909 |
+
# Should we recompute tasks on the fly from Helmholtz? This
|
910 |
+
# should be done if the task is stochastic, or if there are
|
911 |
+
# different kinds of inputs on which it could be run. For
|
912 |
+
# example, lists and strings need this; towers and graphics do
|
913 |
+
# not. There is no harm in recomputed the tasks, it just
|
914 |
+
# wastes time.
|
915 |
+
if not hasattr(self.featureExtractor, 'recomputeTasks'):
|
916 |
+
self.featureExtractor.recomputeTasks = True
|
917 |
+
helmholtzFrontiers = [HelmholtzEntry(f, self)
|
918 |
+
for f in helmholtzFrontiers]
|
919 |
+
random.shuffle(helmholtzFrontiers)
|
920 |
+
|
921 |
+
helmholtzIndex = [0]
|
922 |
+
def getHelmholtz():
|
923 |
+
if randomHelmholtz:
|
924 |
+
if helmholtzIndex[0] >= len(helmholtzFrontiers):
|
925 |
+
updateHelmholtzTasks()
|
926 |
+
helmholtzIndex[0] = 0
|
927 |
+
return getHelmholtz()
|
928 |
+
helmholtzIndex[0] += 1
|
929 |
+
return helmholtzFrontiers[helmholtzIndex[0] - 1].makeFrontier()
|
930 |
+
|
931 |
+
f = helmholtzFrontiers[helmholtzIndex[0]]
|
932 |
+
if f.task is None:
|
933 |
+
with timing("Evaluated another batch of Helmholtz tasks"):
|
934 |
+
updateHelmholtzTasks()
|
935 |
+
return getHelmholtz()
|
936 |
+
|
937 |
+
helmholtzIndex[0] += 1
|
938 |
+
if helmholtzIndex[0] >= len(helmholtzFrontiers):
|
939 |
+
helmholtzIndex[0] = 0
|
940 |
+
random.shuffle(helmholtzFrontiers)
|
941 |
+
if self.featureExtractor.recomputeTasks:
|
942 |
+
for fp in helmholtzFrontiers:
|
943 |
+
fp.clear()
|
944 |
+
return getHelmholtz() # because we just cleared everything
|
945 |
+
assert f.task is not None
|
946 |
+
return f.makeFrontier()
|
947 |
+
|
948 |
+
def updateHelmholtzTasks():
|
949 |
+
updateCPUs = CPUs if hasattr(self.featureExtractor, 'parallelTaskOfProgram') and self.featureExtractor.parallelTaskOfProgram else 1
|
950 |
+
if updateCPUs > 1: eprint("Updating Helmholtz tasks with",updateCPUs,"CPUs",
|
951 |
+
"while using",getThisMemoryUsage(),"memory")
|
952 |
+
|
953 |
+
if randomHelmholtz:
|
954 |
+
newFrontiers = self.sampleManyHelmholtz(requests, helmholtzBatch, CPUs)
|
955 |
+
newEntries = []
|
956 |
+
for f in newFrontiers:
|
957 |
+
e = HelmholtzEntry(f,self)
|
958 |
+
e.task = f.task
|
959 |
+
newEntries.append(e)
|
960 |
+
helmholtzFrontiers.clear()
|
961 |
+
helmholtzFrontiers.extend(newEntries)
|
962 |
+
return
|
963 |
+
|
964 |
+
# Save some memory by freeing up the tasks as we go through them
|
965 |
+
if self.featureExtractor.recomputeTasks:
|
966 |
+
for hi in range(max(0, helmholtzIndex[0] - helmholtzBatch,
|
967 |
+
min(helmholtzIndex[0], len(helmholtzFrontiers)))):
|
968 |
+
helmholtzFrontiers[hi].clear()
|
969 |
+
|
970 |
+
if hasattr(self.featureExtractor, 'tasksOfPrograms'):
|
971 |
+
eprint("batching task calculation")
|
972 |
+
newTasks = self.featureExtractor.tasksOfPrograms(
|
973 |
+
[random.choice(hf.programs)
|
974 |
+
for hf in helmholtzFrontiers[helmholtzIndex[0]:helmholtzIndex[0] + helmholtzBatch] ],
|
975 |
+
[hf.request
|
976 |
+
for hf in helmholtzFrontiers[helmholtzIndex[0]:helmholtzIndex[0] + helmholtzBatch] ])
|
977 |
+
else:
|
978 |
+
newTasks = [hf.calculateTask()
|
979 |
+
for hf in helmholtzFrontiers[helmholtzIndex[0]:helmholtzIndex[0] + helmholtzBatch]]
|
980 |
+
|
981 |
+
"""
|
982 |
+
# catwong: Disabled for ensemble training.
|
983 |
+
newTasks = \
|
984 |
+
parallelMap(updateCPUs,
|
985 |
+
lambda f: f.calculateTask(),
|
986 |
+
helmholtzFrontiers[helmholtzIndex[0]:helmholtzIndex[0] + helmholtzBatch],
|
987 |
+
seedRandom=True)
|
988 |
+
"""
|
989 |
+
badIndices = []
|
990 |
+
endingIndex = min(helmholtzIndex[0] + helmholtzBatch, len(helmholtzFrontiers))
|
991 |
+
for i in range(helmholtzIndex[0], endingIndex):
|
992 |
+
helmholtzFrontiers[i].task = newTasks[i - helmholtzIndex[0]]
|
993 |
+
if helmholtzFrontiers[i].task is None: badIndices.append(i)
|
994 |
+
# Permanently kill anything which failed to give a task
|
995 |
+
for i in reversed(badIndices):
|
996 |
+
assert helmholtzFrontiers[i].task is None
|
997 |
+
del helmholtzFrontiers[i]
|
998 |
+
|
999 |
+
|
1000 |
+
# We replace each program in the frontier with its likelihoodSummary
|
1001 |
+
# This is because calculating likelihood summaries requires juggling types
|
1002 |
+
# And type stuff is expensive!
|
1003 |
+
frontiers = [self.replaceProgramsWithLikelihoodSummaries(f).normalize()
|
1004 |
+
for f in frontiers]
|
1005 |
+
|
1006 |
+
eprint("(ID=%d): Training a recognition model from %d frontiers, %d%% Helmholtz, feature extractor %s." % (
|
1007 |
+
self.id, len(frontiers), int(helmholtzRatio * 100), self.featureExtractor.__class__.__name__))
|
1008 |
+
eprint("(ID=%d): Got %d Helmholtz frontiers - random Helmholtz training? : %s"%(
|
1009 |
+
self.id, len(helmholtzFrontiers), len(helmholtzFrontiers) == 0))
|
1010 |
+
eprint("(ID=%d): Contextual? %s" % (self.id, str(self.contextual)))
|
1011 |
+
eprint("(ID=%d): Bias optimal? %s" % (self.id, str(biasOptimal)))
|
1012 |
+
eprint(f"(ID={self.id}): Aux loss? {auxLoss} (n.b. we train a 'auxiliary' classifier anyway - this controls if gradients propagate back to the future extractor)")
|
1013 |
+
|
1014 |
+
# The number of Helmholtz samples that we generate at once
|
1015 |
+
# Should only affect performance and shouldn't affect anything else
|
1016 |
+
helmholtzSamples = []
|
1017 |
+
|
1018 |
+
optimizer = torch.optim.Adam(self.parameters(), lr=lr, eps=1e-3, amsgrad=True)
|
1019 |
+
start = time.time()
|
1020 |
+
losses, descriptionLengths, realLosses, dreamLosses, realMDL, dreamMDL = [], [], [], [], [], []
|
1021 |
+
classificationLosses = []
|
1022 |
+
totalGradientSteps = 0
|
1023 |
+
epochs = 9999999
|
1024 |
+
for i in range(1, epochs + 1):
|
1025 |
+
if timeout and time.time() - start > timeout:
|
1026 |
+
break
|
1027 |
+
|
1028 |
+
if totalGradientSteps > steps:
|
1029 |
+
break
|
1030 |
+
|
1031 |
+
if helmholtzRatio < 1.:
|
1032 |
+
permutedFrontiers = list(frontiers)
|
1033 |
+
random.shuffle(permutedFrontiers)
|
1034 |
+
else:
|
1035 |
+
permutedFrontiers = [None]
|
1036 |
+
|
1037 |
+
finishedSteps = False
|
1038 |
+
for frontier in permutedFrontiers:
|
1039 |
+
# Randomly decide whether to sample from the generative model
|
1040 |
+
dreaming = random.random() < helmholtzRatio
|
1041 |
+
if dreaming: frontier = getHelmholtz()
|
1042 |
+
self.zero_grad()
|
1043 |
+
loss, classificationLoss = \
|
1044 |
+
self.frontierBiasOptimal(frontier, auxiliary=auxLoss, vectorized=vectorized) if biasOptimal \
|
1045 |
+
else self.frontierKL(frontier, auxiliary=auxLoss, vectorized=vectorized)
|
1046 |
+
if loss is None:
|
1047 |
+
if not dreaming:
|
1048 |
+
eprint("ERROR: Could not extract features during experience replay.")
|
1049 |
+
eprint("Task is:",frontier.task)
|
1050 |
+
eprint("Aborting - we need to be able to extract features of every actual task.")
|
1051 |
+
assert False
|
1052 |
+
else:
|
1053 |
+
continue
|
1054 |
+
if is_torch_invalid(loss):
|
1055 |
+
eprint("Invalid real-data loss!")
|
1056 |
+
else:
|
1057 |
+
(loss + classificationLoss).backward()
|
1058 |
+
classificationLosses.append(classificationLoss.data.item())
|
1059 |
+
optimizer.step()
|
1060 |
+
totalGradientSteps += 1
|
1061 |
+
losses.append(loss.data.item())
|
1062 |
+
descriptionLengths.append(min(-e.logPrior for e in frontier))
|
1063 |
+
if dreaming:
|
1064 |
+
dreamLosses.append(losses[-1])
|
1065 |
+
dreamMDL.append(descriptionLengths[-1])
|
1066 |
+
else:
|
1067 |
+
realLosses.append(losses[-1])
|
1068 |
+
realMDL.append(descriptionLengths[-1])
|
1069 |
+
if totalGradientSteps > steps:
|
1070 |
+
break # Stop iterating, then print epoch and loss, then break to finish.
|
1071 |
+
|
1072 |
+
if (i == 1 or i % 10 == 0) and losses:
|
1073 |
+
eprint("(ID=%d): " % self.id, "Epoch", i, "Loss", mean(losses))
|
1074 |
+
if realLosses and dreamLosses:
|
1075 |
+
eprint("(ID=%d): " % self.id, "\t\t(real loss): ", mean(realLosses), "\t(dream loss):", mean(dreamLosses))
|
1076 |
+
eprint("(ID=%d): " % self.id, "\tvs MDL (w/o neural net)", mean(descriptionLengths))
|
1077 |
+
if realMDL and dreamMDL:
|
1078 |
+
eprint("\t\t(real MDL): ", mean(realMDL), "\t(dream MDL):", mean(dreamMDL))
|
1079 |
+
eprint("(ID=%d): " % self.id, "\t%d cumulative gradient steps. %f steps/sec"%(totalGradientSteps,
|
1080 |
+
totalGradientSteps/(time.time() - start)))
|
1081 |
+
eprint("(ID=%d): " % self.id, "\t%d-way auxiliary classification loss"%len(self.grammar.primitives),sum(classificationLosses)/len(classificationLosses))
|
1082 |
+
losses, descriptionLengths, realLosses, dreamLosses, realMDL, dreamMDL = [], [], [], [], [], []
|
1083 |
+
classificationLosses = []
|
1084 |
+
gc.collect()
|
1085 |
+
|
1086 |
+
eprint("(ID=%d): " % self.id, " Trained recognition model in",time.time() - start,"seconds")
|
1087 |
+
self.trained=True
|
1088 |
+
return self
|
1089 |
+
|
1090 |
+
def sampleHelmholtz(self, requests, statusUpdate=None, seed=None):
|
1091 |
+
if seed is not None:
|
1092 |
+
random.seed(seed)
|
1093 |
+
request = random.choice(requests)
|
1094 |
+
|
1095 |
+
program = self.generativeModel.sample(request, maximumDepth=6, maxAttempts=100)
|
1096 |
+
if program is None:
|
1097 |
+
return None
|
1098 |
+
task = self.featureExtractor.taskOfProgram(program, request)
|
1099 |
+
|
1100 |
+
if statusUpdate is not None:
|
1101 |
+
flushEverything()
|
1102 |
+
if task is None:
|
1103 |
+
return None
|
1104 |
+
|
1105 |
+
if hasattr(self.featureExtractor, 'lexicon'):
|
1106 |
+
if self.featureExtractor.tokenize(task.examples) is None:
|
1107 |
+
return None
|
1108 |
+
|
1109 |
+
ll = self.generativeModel.logLikelihood(request, program)
|
1110 |
+
frontier = Frontier([FrontierEntry(program=program,
|
1111 |
+
logLikelihood=0., logPrior=ll)],
|
1112 |
+
task=task)
|
1113 |
+
return frontier
|
1114 |
+
|
1115 |
+
def sampleManyHelmholtz(self, requests, N, CPUs):
|
1116 |
+
eprint("Sampling %d programs from the prior on %d CPUs..." % (N, CPUs))
|
1117 |
+
flushEverything()
|
1118 |
+
frequency = N / 50
|
1119 |
+
startingSeed = random.random()
|
1120 |
+
|
1121 |
+
# Sequentially for ensemble training.
|
1122 |
+
samples = [self.sampleHelmholtz(requests,
|
1123 |
+
statusUpdate='.' if n % frequency == 0 else None,
|
1124 |
+
seed=startingSeed + n) for n in range(N)]
|
1125 |
+
|
1126 |
+
# (cathywong) Disabled for ensemble training.
|
1127 |
+
# samples = parallelMap(
|
1128 |
+
# 1,
|
1129 |
+
# lambda n: self.sampleHelmholtz(requests,
|
1130 |
+
# statusUpdate='.' if n % frequency == 0 else None,
|
1131 |
+
# seed=startingSeed + n),
|
1132 |
+
# range(N))
|
1133 |
+
eprint()
|
1134 |
+
flushEverything()
|
1135 |
+
samples = [z for z in samples if z is not None]
|
1136 |
+
eprint()
|
1137 |
+
eprint("Got %d/%d valid samples." % (len(samples), N))
|
1138 |
+
flushEverything()
|
1139 |
+
|
1140 |
+
return samples
|
1141 |
+
|
1142 |
+
def enumerateFrontiers(self,
|
1143 |
+
tasks,
|
1144 |
+
enumerationTimeout=None,
|
1145 |
+
testing=False,
|
1146 |
+
solver=None,
|
1147 |
+
CPUs=1,
|
1148 |
+
frontierSize=None,
|
1149 |
+
maximumFrontier=None,
|
1150 |
+
evaluationTimeout=None):
|
1151 |
+
with timing("Evaluated recognition model"):
|
1152 |
+
grammars = {task: self.grammarOfTask(task)
|
1153 |
+
for task in tasks}
|
1154 |
+
#untorch seperately to make sure you filter out None grammars
|
1155 |
+
grammars = {task: grammar.untorch() for task, grammar in grammars.items() if grammar is not None}
|
1156 |
+
|
1157 |
+
return multicoreEnumeration(grammars, tasks,
|
1158 |
+
testing=testing,
|
1159 |
+
solver=solver,
|
1160 |
+
enumerationTimeout=enumerationTimeout,
|
1161 |
+
CPUs=CPUs, maximumFrontier=maximumFrontier,
|
1162 |
+
evaluationTimeout=evaluationTimeout)
|
1163 |
+
|
1164 |
+
|
1165 |
+
class RecurrentFeatureExtractor(nn.Module):
|
1166 |
+
def __init__(self, _=None,
|
1167 |
+
tasks=None,
|
1168 |
+
cuda=False,
|
1169 |
+
# what are the symbols that can occur in the inputs and
|
1170 |
+
# outputs
|
1171 |
+
lexicon=None,
|
1172 |
+
# how many hidden units
|
1173 |
+
H=32,
|
1174 |
+
# Should the recurrent units be bidirectional?
|
1175 |
+
bidirectional=False,
|
1176 |
+
# What should be the timeout for trying to construct Helmholtz tasks?
|
1177 |
+
helmholtzTimeout=0.25,
|
1178 |
+
# What should be the timeout for running a Helmholtz program?
|
1179 |
+
helmholtzEvaluationTimeout=0.01):
|
1180 |
+
super(RecurrentFeatureExtractor, self).__init__()
|
1181 |
+
|
1182 |
+
assert tasks is not None, "You must provide a list of all of the tasks, both those that have been hit and those that have not been hit. Input examples are sampled from these tasks."
|
1183 |
+
|
1184 |
+
# maps from a requesting type to all of the inputs that we ever saw with that request
|
1185 |
+
self.requestToInputs = {
|
1186 |
+
tp: [list(map(fst, t.examples)) for t in tasks if t.request == tp ]
|
1187 |
+
for tp in {t.request for t in tasks}
|
1188 |
+
}
|
1189 |
+
|
1190 |
+
inputTypes = {t
|
1191 |
+
for task in tasks
|
1192 |
+
for t in task.request.functionArguments()}
|
1193 |
+
# maps from a type to all of the inputs that we ever saw having that type
|
1194 |
+
self.argumentsWithType = {
|
1195 |
+
tp: [ x
|
1196 |
+
for t in tasks
|
1197 |
+
for xs,_ in t.examples
|
1198 |
+
for tpp, x in zip(t.request.functionArguments(), xs)
|
1199 |
+
if tpp == tp]
|
1200 |
+
for tp in inputTypes
|
1201 |
+
}
|
1202 |
+
self.requestToNumberOfExamples = {
|
1203 |
+
tp: [ len(t.examples)
|
1204 |
+
for t in tasks if t.request == tp ]
|
1205 |
+
for tp in {t.request for t in tasks}
|
1206 |
+
}
|
1207 |
+
self.helmholtzTimeout = helmholtzTimeout
|
1208 |
+
self.helmholtzEvaluationTimeout = helmholtzEvaluationTimeout
|
1209 |
+
self.parallelTaskOfProgram = True
|
1210 |
+
|
1211 |
+
assert lexicon
|
1212 |
+
self.specialSymbols = [
|
1213 |
+
"STARTING", # start of entire sequence
|
1214 |
+
"ENDING", # ending of entire sequence
|
1215 |
+
"STARTOFOUTPUT", # begins the start of the output
|
1216 |
+
"ENDOFINPUT" # delimits the ending of an input - we might have multiple inputs
|
1217 |
+
]
|
1218 |
+
lexicon += self.specialSymbols
|
1219 |
+
encoder = nn.Embedding(len(lexicon), H)
|
1220 |
+
self.encoder = encoder
|
1221 |
+
|
1222 |
+
self.H = H
|
1223 |
+
self.bidirectional = bidirectional
|
1224 |
+
|
1225 |
+
layers = 1
|
1226 |
+
|
1227 |
+
model = nn.GRU(H, H, layers, bidirectional=bidirectional)
|
1228 |
+
self.model = model
|
1229 |
+
|
1230 |
+
self.use_cuda = cuda
|
1231 |
+
self.lexicon = lexicon
|
1232 |
+
self.symbolToIndex = {
|
1233 |
+
symbol: index for index,
|
1234 |
+
symbol in enumerate(lexicon)}
|
1235 |
+
self.startingIndex = self.symbolToIndex["STARTING"]
|
1236 |
+
self.endingIndex = self.symbolToIndex["ENDING"]
|
1237 |
+
self.startOfOutputIndex = self.symbolToIndex["STARTOFOUTPUT"]
|
1238 |
+
self.endOfInputIndex = self.symbolToIndex["ENDOFINPUT"]
|
1239 |
+
|
1240 |
+
# Maximum number of inputs/outputs we will run the recognition
|
1241 |
+
# model on per task
|
1242 |
+
# This is an optimization hack
|
1243 |
+
self.MAXINPUTS = 100
|
1244 |
+
|
1245 |
+
if cuda: self.cuda()
|
1246 |
+
|
1247 |
+
@property
|
1248 |
+
def outputDimensionality(self): return self.H
|
1249 |
+
|
1250 |
+
# modify examples before forward (to turn them into iterables of lexicon)
|
1251 |
+
# you should override this if needed
|
1252 |
+
def tokenize(self, x): return x
|
1253 |
+
|
1254 |
+
def symbolEmbeddings(self):
|
1255 |
+
return {s: self.encoder(variable([self.symbolToIndex[s]])).squeeze(
|
1256 |
+
0).data.cpu().numpy() for s in self.lexicon if not (s in self.specialSymbols)}
|
1257 |
+
|
1258 |
+
def packExamples(self, examples):
|
1259 |
+
"""IMPORTANT! xs must be sorted in decreasing order of size because pytorch is stupid"""
|
1260 |
+
es = []
|
1261 |
+
sizes = []
|
1262 |
+
for xs, y in examples:
|
1263 |
+
e = [self.startingIndex]
|
1264 |
+
for x in xs:
|
1265 |
+
for s in x:
|
1266 |
+
e.append(self.symbolToIndex[s])
|
1267 |
+
e.append(self.endOfInputIndex)
|
1268 |
+
e.append(self.startOfOutputIndex)
|
1269 |
+
for s in y:
|
1270 |
+
e.append(self.symbolToIndex[s])
|
1271 |
+
e.append(self.endingIndex)
|
1272 |
+
if es != []:
|
1273 |
+
assert len(e) <= len(es[-1]), \
|
1274 |
+
"Examples must be sorted in decreasing order of their tokenized size. This should be transparently handled in recognition.py, so if this assertion fails it isn't your fault as a user of EC but instead is a bug inside of EC."
|
1275 |
+
es.append(e)
|
1276 |
+
sizes.append(len(e))
|
1277 |
+
|
1278 |
+
m = max(sizes)
|
1279 |
+
# padding
|
1280 |
+
for j, e in enumerate(es):
|
1281 |
+
es[j] += [self.endingIndex] * (m - len(e))
|
1282 |
+
|
1283 |
+
x = variable(es, cuda=self.use_cuda)
|
1284 |
+
x = self.encoder(x)
|
1285 |
+
# x: (batch size, maximum length, E)
|
1286 |
+
x = x.permute(1, 0, 2)
|
1287 |
+
# x: TxBxE
|
1288 |
+
x = pack_padded_sequence(x, sizes)
|
1289 |
+
return x, sizes
|
1290 |
+
|
1291 |
+
def examplesEncoding(self, examples):
|
1292 |
+
examples = sorted(examples, key=lambda xs_y: sum(
|
1293 |
+
len(z) + 1 for z in xs_y[0]) + len(xs_y[1]), reverse=True)
|
1294 |
+
x, sizes = self.packExamples(examples)
|
1295 |
+
outputs, hidden = self.model(x)
|
1296 |
+
# outputs, sizes = pad_packed_sequence(outputs)
|
1297 |
+
# I don't know whether to return the final output or the final hidden
|
1298 |
+
# activations...
|
1299 |
+
return hidden[0, :, :] + hidden[1, :, :]
|
1300 |
+
|
1301 |
+
def forward(self, examples):
|
1302 |
+
tokenized = self.tokenize(examples)
|
1303 |
+
if not tokenized:
|
1304 |
+
return None
|
1305 |
+
|
1306 |
+
if hasattr(self, 'MAXINPUTS') and len(tokenized) > self.MAXINPUTS:
|
1307 |
+
tokenized = list(tokenized)
|
1308 |
+
random.shuffle(tokenized)
|
1309 |
+
tokenized = tokenized[:self.MAXINPUTS]
|
1310 |
+
e = self.examplesEncoding(tokenized)
|
1311 |
+
# max pool
|
1312 |
+
# e,_ = e.max(dim = 0)
|
1313 |
+
|
1314 |
+
# take the average activations across all of the examples
|
1315 |
+
# I think this might be better because we might be testing on data
|
1316 |
+
# which has far more o far fewer examples then training
|
1317 |
+
e = e.mean(dim=0)
|
1318 |
+
return e
|
1319 |
+
|
1320 |
+
def featuresOfTask(self, t):
|
1321 |
+
if hasattr(self, 'useFeatures'):
|
1322 |
+
f = self(t.features)
|
1323 |
+
else:
|
1324 |
+
# Featurize the examples directly.
|
1325 |
+
f = self(t.examples)
|
1326 |
+
return f
|
1327 |
+
|
1328 |
+
def taskOfProgram(self, p, tp):
|
1329 |
+
# half of the time we randomly mix together inputs
|
1330 |
+
# this gives better generalization on held out tasks
|
1331 |
+
# the other half of the time we train on sets of inputs in the training data
|
1332 |
+
# this gives better generalization on unsolved training tasks
|
1333 |
+
if random.random() < 0.5:
|
1334 |
+
def randomInput(t): return random.choice(self.argumentsWithType[t])
|
1335 |
+
# Loop over the inputs in a random order and pick the first ones that
|
1336 |
+
# doesn't generate an exception
|
1337 |
+
|
1338 |
+
startTime = time.time()
|
1339 |
+
examples = []
|
1340 |
+
while True:
|
1341 |
+
# TIMEOUT! this must not be a very good program
|
1342 |
+
if time.time() - startTime > self.helmholtzTimeout: return None
|
1343 |
+
|
1344 |
+
# Grab some random inputs
|
1345 |
+
xs = [randomInput(t) for t in tp.functionArguments()]
|
1346 |
+
try:
|
1347 |
+
y = runWithTimeout(lambda: p.runWithArguments(xs), self.helmholtzEvaluationTimeout)
|
1348 |
+
examples.append((tuple(xs),y))
|
1349 |
+
if len(examples) >= random.choice(self.requestToNumberOfExamples[tp]):
|
1350 |
+
return Task("Helmholtz", tp, examples)
|
1351 |
+
except: continue
|
1352 |
+
|
1353 |
+
else:
|
1354 |
+
candidateInputs = list(self.requestToInputs[tp])
|
1355 |
+
random.shuffle(candidateInputs)
|
1356 |
+
for xss in candidateInputs:
|
1357 |
+
ys = []
|
1358 |
+
for xs in xss:
|
1359 |
+
try: y = runWithTimeout(lambda: p.runWithArguments(xs), self.helmholtzEvaluationTimeout)
|
1360 |
+
except: break
|
1361 |
+
ys.append(y)
|
1362 |
+
if len(ys) == len(xss):
|
1363 |
+
return Task("Helmholtz", tp, list(zip(xss, ys)))
|
1364 |
+
return None
|
1365 |
+
|
1366 |
+
|
1367 |
+
|
1368 |
+
class LowRank(nn.Module):
|
1369 |
+
"""
|
1370 |
+
Module that outputs a rank R matrix of size m by n from input of size i.
|
1371 |
+
"""
|
1372 |
+
def __init__(self, i, m, n, r):
|
1373 |
+
"""
|
1374 |
+
i: input dimension
|
1375 |
+
m: output rows
|
1376 |
+
n: output columns
|
1377 |
+
r: maximum rank. if this is None then the output will be full-rank
|
1378 |
+
"""
|
1379 |
+
super(LowRank, self).__init__()
|
1380 |
+
|
1381 |
+
self.m = m
|
1382 |
+
self.n = n
|
1383 |
+
|
1384 |
+
maximumPossibleRank = min(m, n)
|
1385 |
+
if r is None: r = maximumPossibleRank
|
1386 |
+
|
1387 |
+
if r < maximumPossibleRank:
|
1388 |
+
self.factored = True
|
1389 |
+
self.A = nn.Linear(i, m*r)
|
1390 |
+
self.B = nn.Linear(i, n*r)
|
1391 |
+
self.r = r
|
1392 |
+
else:
|
1393 |
+
self.factored = False
|
1394 |
+
self.M = nn.Linear(i, m*n)
|
1395 |
+
|
1396 |
+
def forward(self, x):
|
1397 |
+
sz = x.size()
|
1398 |
+
if len(sz) == 1:
|
1399 |
+
B = 1
|
1400 |
+
x = x.unsqueeze(0)
|
1401 |
+
needToSqueeze = True
|
1402 |
+
elif len(sz) == 2:
|
1403 |
+
B = sz[0]
|
1404 |
+
needToSqueeze = False
|
1405 |
+
else:
|
1406 |
+
assert False, "LowRank expects either a 1-dimensional tensor or a 2-dimensional tensor"
|
1407 |
+
|
1408 |
+
if self.factored:
|
1409 |
+
a = self.A(x).view(B, self.m, self.r)
|
1410 |
+
b = self.B(x).view(B, self.r, self.n)
|
1411 |
+
y = a @ b
|
1412 |
+
else:
|
1413 |
+
y = self.M(x).view(B, self.m, self.n)
|
1414 |
+
if needToSqueeze:
|
1415 |
+
y = y.squeeze(0)
|
1416 |
+
return y
|
1417 |
+
|
1418 |
+
|
1419 |
+
|
1420 |
+
|
1421 |
+
class DummyFeatureExtractor(nn.Module):
|
1422 |
+
def __init__(self, tasks, testingTasks=[], cuda=False):
|
1423 |
+
super(DummyFeatureExtractor, self).__init__()
|
1424 |
+
self.outputDimensionality = 1
|
1425 |
+
self.recomputeTasks = False
|
1426 |
+
def featuresOfTask(self, t):
|
1427 |
+
return variable([0.]).float()
|
1428 |
+
def featuresOfTasks(self, ts):
|
1429 |
+
return variable([[0.]]*len(ts)).float()
|
1430 |
+
def taskOfProgram(self, p, t):
|
1431 |
+
return Task("dummy task", t, [])
|
1432 |
+
|
1433 |
+
class RandomFeatureExtractor(nn.Module):
|
1434 |
+
def __init__(self, tasks):
|
1435 |
+
super(RandomFeatureExtractor, self).__init__()
|
1436 |
+
self.outputDimensionality = 1
|
1437 |
+
self.recomputeTasks = False
|
1438 |
+
def featuresOfTask(self, t):
|
1439 |
+
return variable([random.random()]).float()
|
1440 |
+
def featuresOfTasks(self, ts):
|
1441 |
+
return variable([[random.random()] for _ in range(len(ts)) ]).float()
|
1442 |
+
def taskOfProgram(self, p, t):
|
1443 |
+
return Task("dummy task", t, [])
|
1444 |
+
|
1445 |
+
class Flatten(nn.Module):
|
1446 |
+
def __init__(self):
|
1447 |
+
super(Flatten, self).__init__()
|
1448 |
+
|
1449 |
+
def forward(self, x):
|
1450 |
+
return x.view(x.size(0), -1)
|
1451 |
+
|
1452 |
+
class ImageFeatureExtractor(nn.Module):
|
1453 |
+
def __init__(self, inputImageDimension, resizedDimension=None,
|
1454 |
+
channels=1):
|
1455 |
+
super(ImageFeatureExtractor, self).__init__()
|
1456 |
+
|
1457 |
+
self.resizedDimension = resizedDimension or inputImageDimension
|
1458 |
+
self.inputImageDimension = inputImageDimension
|
1459 |
+
self.channels = channels
|
1460 |
+
|
1461 |
+
def conv_block(in_channels, out_channels):
|
1462 |
+
return nn.Sequential(
|
1463 |
+
nn.Conv2d(in_channels, out_channels, 3, padding=1),
|
1464 |
+
# nn.BatchNorm2d(out_channels),
|
1465 |
+
nn.ReLU(),
|
1466 |
+
nn.MaxPool2d(2)
|
1467 |
+
)
|
1468 |
+
|
1469 |
+
# channels for hidden
|
1470 |
+
hid_dim = 64
|
1471 |
+
z_dim = 64
|
1472 |
+
|
1473 |
+
self.encoder = nn.Sequential(
|
1474 |
+
conv_block(channels, hid_dim),
|
1475 |
+
conv_block(hid_dim, hid_dim),
|
1476 |
+
conv_block(hid_dim, hid_dim),
|
1477 |
+
conv_block(hid_dim, z_dim),
|
1478 |
+
Flatten()
|
1479 |
+
)
|
1480 |
+
|
1481 |
+
# Each layer of the encoder halves the dimension, except for the last layer which flattens
|
1482 |
+
outputImageDimensionality = self.resizedDimension/(2**(len(self.encoder) - 1))
|
1483 |
+
self.outputDimensionality = int(z_dim*outputImageDimensionality*outputImageDimensionality)
|
1484 |
+
|
1485 |
+
def forward(self, v):
|
1486 |
+
"""1 channel: v: BxWxW or v:WxW
|
1487 |
+
> 1 channel: v: BxCxWxW or v:CxWxW"""
|
1488 |
+
|
1489 |
+
insertBatch = False
|
1490 |
+
variabled = variable(v).float()
|
1491 |
+
if self.channels == 1: # insert channel dimension
|
1492 |
+
if len(variabled.shape) == 3: # batching
|
1493 |
+
variabled = variabled[:,None,:,:]
|
1494 |
+
elif len(variabled.shape) == 2: # no batching
|
1495 |
+
variabled = variabled[None,:,:]
|
1496 |
+
insertBatch = True
|
1497 |
+
else: assert False
|
1498 |
+
else: # expect to have a channel dimension
|
1499 |
+
if len(variabled.shape) == 4:
|
1500 |
+
pass
|
1501 |
+
elif len(variabled.shape) == 3:
|
1502 |
+
insertBatch = True
|
1503 |
+
else: assert False
|
1504 |
+
|
1505 |
+
if insertBatch: variabled = torch.unsqueeze(variabled, 0)
|
1506 |
+
|
1507 |
+
y = self.encoder(variabled)
|
1508 |
+
if insertBatch: y = y[0,:]
|
1509 |
+
return y
|
1510 |
+
|
1511 |
+
class JSONFeatureExtractor(object):
|
1512 |
+
def __init__(self, tasks, cudaFalse):
|
1513 |
+
# self.averages, self.deviations = Task.featureMeanAndStandardDeviation(tasks)
|
1514 |
+
# self.outputDimensionality = len(self.averages)
|
1515 |
+
self.cuda = cuda
|
1516 |
+
self.tasks = tasks
|
1517 |
+
|
1518 |
+
def stringify(self, x):
|
1519 |
+
# No whitespace #maybe kill the seperators
|
1520 |
+
return json.dumps(x, separators=(',', ':'))
|
1521 |
+
|
1522 |
+
def featuresOfTask(self, t):
|
1523 |
+
# >>> t.request to get the type
|
1524 |
+
# >>> t.examples to get input/output examples
|
1525 |
+
# this might actually be okay, because the input should just be nothing
|
1526 |
+
#return [(self.stringify(inputs), self.stringify(output))
|
1527 |
+
# for (inputs, output) in t.examples]
|
1528 |
+
return [(list(output),) for (inputs, output) in t.examples]
|
dreamcoder/task.py
ADDED
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dreamcoder.program import *
|
2 |
+
from dreamcoder.differentiation import *
|
3 |
+
|
4 |
+
import signal
|
5 |
+
|
6 |
+
|
7 |
+
class EvaluationTimeout(Exception):
|
8 |
+
pass
|
9 |
+
|
10 |
+
|
11 |
+
EVALUATIONTABLE = {}
|
12 |
+
|
13 |
+
|
14 |
+
class Task(object):
|
15 |
+
def __init__(self, name, request, examples, features=None, cache=False):
|
16 |
+
'''request: the type of this task
|
17 |
+
examples: list of tuples of (input, output). input should be a tuple, with one entry for each argument
|
18 |
+
cache: should program evaluations be cached?
|
19 |
+
features: list of floats.'''
|
20 |
+
self.cache = cache
|
21 |
+
self.features = features
|
22 |
+
self.request = request
|
23 |
+
self.name = name
|
24 |
+
self.examples = examples
|
25 |
+
if len(self.examples) > 0:
|
26 |
+
assert all(len(xs) == len(examples[0][0])
|
27 |
+
for xs, _ in examples), \
|
28 |
+
"(for task %s) FATAL: Number of arguments varies." % name
|
29 |
+
|
30 |
+
def __str__(self):
|
31 |
+
if self.supervision is None:
|
32 |
+
return self.name
|
33 |
+
else:
|
34 |
+
return self.name + " (%s)"%self.supervision
|
35 |
+
|
36 |
+
def __repr__(self):
|
37 |
+
return "Task(name={self.name}, request={self.request}, examples={self.examples}"\
|
38 |
+
.format(self=self)
|
39 |
+
|
40 |
+
def __eq__(self, o): return self.name == o.name
|
41 |
+
|
42 |
+
def __ne__(self, o): return not (self == o)
|
43 |
+
|
44 |
+
def __hash__(self): return hash(self.name)
|
45 |
+
|
46 |
+
def describe(self):
|
47 |
+
description = ["%s : %s" % (self.name, self.request)]
|
48 |
+
for xs, y in self.examples:
|
49 |
+
if len(xs) == 1:
|
50 |
+
description.append("f(%s) = %s" % (xs[0], y))
|
51 |
+
else:
|
52 |
+
description.append("f%s = %s" % (xs, y))
|
53 |
+
return "\n".join(description)
|
54 |
+
|
55 |
+
def predict(self, f, x):
|
56 |
+
for a in x:
|
57 |
+
f = f(a)
|
58 |
+
return f
|
59 |
+
|
60 |
+
@property
|
61 |
+
def supervision(self):
|
62 |
+
if not hasattr(self, 'supervisedSolution'): return None
|
63 |
+
return self.supervisedSolution
|
64 |
+
|
65 |
+
def check(self, e, timeout=None):
|
66 |
+
if timeout is not None:
|
67 |
+
def timeoutCallBack(_1, _2): raise EvaluationTimeout()
|
68 |
+
try:
|
69 |
+
signal.signal(signal.SIGVTALRM, timeoutCallBack)
|
70 |
+
signal.setitimer(signal.ITIMER_VIRTUAL, timeout)
|
71 |
+
|
72 |
+
try:
|
73 |
+
f = e.evaluate([])
|
74 |
+
except IndexError:
|
75 |
+
# free variable
|
76 |
+
return False
|
77 |
+
except Exception as e:
|
78 |
+
eprint("Exception during evaluation:", e)
|
79 |
+
return False
|
80 |
+
|
81 |
+
for x, y in self.examples:
|
82 |
+
if self.cache and (x, e) in EVALUATIONTABLE:
|
83 |
+
p = EVALUATIONTABLE[(x, e)]
|
84 |
+
else:
|
85 |
+
try:
|
86 |
+
p = self.predict(f, x)
|
87 |
+
except BaseException:
|
88 |
+
p = None
|
89 |
+
if self.cache:
|
90 |
+
EVALUATIONTABLE[(x, e)] = p
|
91 |
+
if p != y:
|
92 |
+
if timeout is not None:
|
93 |
+
signal.signal(signal.SIGVTALRM, lambda *_: None)
|
94 |
+
signal.setitimer(signal.ITIMER_VIRTUAL, 0)
|
95 |
+
return False
|
96 |
+
|
97 |
+
return True
|
98 |
+
# except e:
|
99 |
+
# eprint(e)
|
100 |
+
# assert(False)
|
101 |
+
except EvaluationTimeout:
|
102 |
+
eprint("Timed out while evaluating", e)
|
103 |
+
return False
|
104 |
+
finally:
|
105 |
+
if timeout is not None:
|
106 |
+
signal.signal(signal.SIGVTALRM, lambda *_: None)
|
107 |
+
signal.setitimer(signal.ITIMER_VIRTUAL, 0)
|
108 |
+
|
109 |
+
def logLikelihood(self, e, timeout=None):
|
110 |
+
if self.check(e, timeout):
|
111 |
+
return 0.0
|
112 |
+
else:
|
113 |
+
return NEGATIVEINFINITY
|
114 |
+
|
115 |
+
@staticmethod
|
116 |
+
def featureMeanAndStandardDeviation(tasks):
|
117 |
+
dimension = len(tasks[0].features)
|
118 |
+
averages = [sum(t.features[j] for t in tasks) / float(len(tasks))
|
119 |
+
for j in range(dimension)]
|
120 |
+
variances = [sum((t.features[j] -
|
121 |
+
averages[j])**2 for t in tasks) /
|
122 |
+
float(len(tasks)) for j in range(dimension)]
|
123 |
+
standardDeviations = [v**0.5 for v in variances]
|
124 |
+
for j, s in enumerate(standardDeviations):
|
125 |
+
if s == 0.:
|
126 |
+
eprint(
|
127 |
+
"WARNING: Feature %d is always %f" %
|
128 |
+
(j + 1, averages[j]))
|
129 |
+
return averages, standardDeviations
|
130 |
+
|
131 |
+
def as_json_dict(self):
|
132 |
+
return {
|
133 |
+
"name": self.name,
|
134 |
+
"request": str(self.request),
|
135 |
+
"examples": [{"inputs": x, "output": y} for x, y in self.examples]
|
136 |
+
}
|
137 |
+
|
138 |
+
|
139 |
+
class DifferentiableTask(Task):
|
140 |
+
|
141 |
+
def __init__(self, name, request, examples, _=None,
|
142 |
+
features=None, BIC=1., loss=None, likelihoodThreshold=None,
|
143 |
+
steps=50, restarts=300, lr=0.5, decay=0.5, grow=1.2, actualParameters=None,
|
144 |
+
temperature=1., maxParameters=None, clipLoss=None, clipOutput=None):
|
145 |
+
assert loss is not None
|
146 |
+
self.temperature = temperature
|
147 |
+
self.actualParameters = actualParameters
|
148 |
+
self.maxParameters = maxParameters
|
149 |
+
self.loss = loss
|
150 |
+
self.BIC = BIC
|
151 |
+
self.likelihoodThreshold = likelihoodThreshold
|
152 |
+
|
153 |
+
arguments = {"parameterPenalty": BIC * math.log(len(examples)),
|
154 |
+
"temperature": temperature,
|
155 |
+
"steps": steps, "restarts": restarts, "lr": lr, "decay": decay, "grow": grow,
|
156 |
+
"maxParameters": maxParameters,
|
157 |
+
"lossThreshold": -likelihoodThreshold}
|
158 |
+
if clipLoss is not None: arguments['clipLoss'] = float(clipLoss)
|
159 |
+
if clipOutput is not None: arguments['clipOutput'] = float(clipOutput)
|
160 |
+
if actualParameters is not None: arguments['actualParameters'] = int(actualParameters)
|
161 |
+
|
162 |
+
self.specialTask = ("differentiable",
|
163 |
+
arguments)
|
164 |
+
|
165 |
+
super(
|
166 |
+
DifferentiableTask,
|
167 |
+
self).__init__(
|
168 |
+
name,
|
169 |
+
request,
|
170 |
+
examples,
|
171 |
+
features,
|
172 |
+
cache=False)
|
173 |
+
|
174 |
+
def logLikelihood(self, e, timeout=None):
|
175 |
+
assert timeout is None, "timeout not implemented for differentiable tasks, but not for any good reason."
|
176 |
+
e, parameters = PlaceholderVisitor.execute(e)
|
177 |
+
if self.maxParameters is not None and len(
|
178 |
+
parameters) > self.maxParameters:
|
179 |
+
return NEGATIVEINFINITY
|
180 |
+
if self.actualParameters is not None and len(
|
181 |
+
parameters) > self.actualParameters:
|
182 |
+
return NEGATIVEINFINITY
|
183 |
+
f = e.evaluate([])
|
184 |
+
|
185 |
+
loss = sum(self.loss(self.predict(f, xs), y)
|
186 |
+
for xs, y in self.examples) / float(len(self.examples))
|
187 |
+
if isinstance(loss, DN):
|
188 |
+
try:
|
189 |
+
loss = loss.restartingOptimize(
|
190 |
+
parameters,
|
191 |
+
lr=self.specialTask[1]["lr"],
|
192 |
+
steps=self.specialTask[1]["steps"],
|
193 |
+
decay=self.specialTask[1]["decay"],
|
194 |
+
grow=self.specialTask[1]["grow"],
|
195 |
+
attempts=self.specialTask[1]["restarts"],
|
196 |
+
update=None)
|
197 |
+
except InvalidLoss:
|
198 |
+
loss = POSITIVEINFINITY
|
199 |
+
|
200 |
+
# BIC penalty
|
201 |
+
penalty = self.BIC * len(parameters) * math.log(len(self.examples))
|
202 |
+
|
203 |
+
if self.likelihoodThreshold is not None:
|
204 |
+
if loss > -self.likelihoodThreshold:
|
205 |
+
return NEGATIVEINFINITY
|
206 |
+
else:
|
207 |
+
return -penalty
|
208 |
+
else:
|
209 |
+
return -loss / self.temperature - penalty
|
210 |
+
|
211 |
+
|
212 |
+
def squaredErrorLoss(prediction, target):
|
213 |
+
d = prediction - target
|
214 |
+
return d * d
|
215 |
+
|
216 |
+
|
217 |
+
def l1loss(prediction, target):
|
218 |
+
return abs(prediction - target)
|
219 |
+
|
220 |
+
|
221 |
+
class PlaceholderVisitor(object):
|
222 |
+
def __init__(self): self.parameters = []
|
223 |
+
|
224 |
+
def primitive(self, e):
|
225 |
+
if e.name == 'REAL':
|
226 |
+
placeholder = Placeholder.named("REAL_", random.random())
|
227 |
+
self.parameters.append(placeholder)
|
228 |
+
return Primitive(e.name, e.tp, placeholder)
|
229 |
+
return e
|
230 |
+
|
231 |
+
def invented(self, e): return e.body.visit(self)
|
232 |
+
|
233 |
+
def abstraction(self, e): return Abstraction(e.body.visit(self))
|
234 |
+
|
235 |
+
def application(self, e):
|
236 |
+
return Application(e.f.visit(self), e.x.visit(self))
|
237 |
+
|
238 |
+
def index(self, e): return e
|
239 |
+
|
240 |
+
@staticmethod
|
241 |
+
def execute(e):
|
242 |
+
v = PlaceholderVisitor()
|
243 |
+
e = e.visit(v)
|
244 |
+
return e, v.parameters
|
dreamcoder/taskBatcher.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dreamcoder.utilities import eprint
|
2 |
+
import random
|
3 |
+
|
4 |
+
|
5 |
+
class DefaultTaskBatcher:
|
6 |
+
"""Iterates through task batches of the specified size. Defaults to all tasks if taskBatchSize is None."""
|
7 |
+
|
8 |
+
def __init__(self):
|
9 |
+
pass
|
10 |
+
|
11 |
+
def getTaskBatch(self, ec_result, tasks, taskBatchSize, currIteration):
|
12 |
+
if taskBatchSize is None:
|
13 |
+
taskBatchSize = len(tasks)
|
14 |
+
elif taskBatchSize > len(tasks):
|
15 |
+
eprint("Task batch size is greater than total number of tasks, aborting.")
|
16 |
+
assert False
|
17 |
+
|
18 |
+
|
19 |
+
start = (taskBatchSize * currIteration) % len(tasks)
|
20 |
+
end = start + taskBatchSize
|
21 |
+
taskBatch = (tasks + tasks)[start:end] # Handle wraparound.
|
22 |
+
return taskBatch
|
23 |
+
|
24 |
+
class RandomTaskBatcher:
|
25 |
+
"""Returns a randomly sampled task batch of the specified size. Defaults to all tasks if taskBatchSize is None."""
|
26 |
+
|
27 |
+
def __init__(self):
|
28 |
+
pass
|
29 |
+
|
30 |
+
def getTaskBatch(self, ec_result, tasks, taskBatchSize, currIteration):
|
31 |
+
if taskBatchSize is None:
|
32 |
+
taskBatchSize = len(tasks)
|
33 |
+
elif taskBatchSize > len(tasks):
|
34 |
+
eprint("Task batch size is greater than total number of tasks, aborting.")
|
35 |
+
assert False
|
36 |
+
|
37 |
+
return random.sample(tasks, taskBatchSize)
|
38 |
+
|
39 |
+
class RandomShuffleTaskBatcher:
|
40 |
+
"""Randomly shuffles the task batch first, and then iterates through task batches of the specified size like DefaultTaskBatcher.
|
41 |
+
Reshuffles across iterations - intended as benchmark comparison to test the task ordering."""
|
42 |
+
def __init__(self, baseSeed=0): self.baseSeed = baseSeed
|
43 |
+
|
44 |
+
def getTaskBatch(self, ec_result, tasks, taskBatchSize, currIteration):
|
45 |
+
if taskBatchSize is None:
|
46 |
+
taskBatchSize = len(tasks)
|
47 |
+
elif taskBatchSize > len(tasks):
|
48 |
+
eprint("Task batch size is greater than total number of tasks, aborting.")
|
49 |
+
assert False
|
50 |
+
|
51 |
+
# Reshuffles tasks in a fixed way across epochs for reproducibility.
|
52 |
+
currEpoch = int(int(currIteration * taskBatchSize) / int(len(tasks)))
|
53 |
+
|
54 |
+
shuffledTasks = tasks.copy() # Since shuffle works in place.
|
55 |
+
random.Random(self.baseSeed + currEpoch).shuffle(shuffledTasks)
|
56 |
+
|
57 |
+
shuffledTasksWrap = tasks.copy() # Since shuffle works in place.
|
58 |
+
random.Random(self.baseSeed + currEpoch + 1).shuffle(shuffledTasksWrap)
|
59 |
+
|
60 |
+
start = (taskBatchSize * currIteration) % len(shuffledTasks)
|
61 |
+
end = start + taskBatchSize
|
62 |
+
taskBatch = (shuffledTasks + shuffledTasksWrap)[start:end] # Wraparound nicely.
|
63 |
+
|
64 |
+
return list(set(taskBatch))
|
65 |
+
|
66 |
+
class UnsolvedTaskBatcher:
|
67 |
+
"""At a given epoch, returns only batches of the tasks that have not been solved at least twice"""
|
68 |
+
|
69 |
+
def __init__(self):
|
70 |
+
self.timesSolved = {} # map from task to times that we have solved it
|
71 |
+
self.start = 0
|
72 |
+
|
73 |
+
def getTaskBatch(self, ec_result, tasks, taskBatchSize, currIteration):
|
74 |
+
assert taskBatchSize is None, "This batching strategy does not support batch sizes"
|
75 |
+
|
76 |
+
for t,f in ec_result.allFrontiers.items():
|
77 |
+
if f.empty:
|
78 |
+
self.timesSolved[t] = max(0, self.timesSolved.get(t,0))
|
79 |
+
else:
|
80 |
+
self.timesSolved[t] = 1 + self.timesSolved.get(t, 0)
|
81 |
+
return [t for t in tasks if self.timesSolved.get(t,0) < 2 ]
|
82 |
+
|
83 |
+
def entropyRandomBatch(ec_result, tasks, taskBatchSize, randomRatio):
|
84 |
+
numRandom = int(randomRatio * taskBatchSize)
|
85 |
+
numEntropy = taskBatchSize - numRandom
|
86 |
+
|
87 |
+
eprint("Selecting top %d tasks from the %d overall tasks given lowest entropy." % (taskBatchSize, len(tasks)))
|
88 |
+
eprint("Will be selecting %d by lowest entropy and %d randomly." %(numEntropy, numRandom))
|
89 |
+
taskGrammarEntropies = ec_result.recognitionModel.taskGrammarEntropies(tasks)
|
90 |
+
sortedEntropies = sorted(taskGrammarEntropies.items(), key=lambda x:x[1])
|
91 |
+
|
92 |
+
entropyBatch = [task for (task, entropy) in sortedEntropies[:numEntropy]]
|
93 |
+
randomBatch = random.sample([task for (task, entropy) in sortedEntropies[numEntropy:]], numRandom)
|
94 |
+
batch = entropyBatch + randomBatch
|
95 |
+
|
96 |
+
return batch
|
97 |
+
|
98 |
+
def kNearestNeighbors(ec_result, tasks, k, task):
|
99 |
+
"""Finds the k nearest neighbors in the recognition model logProduction space to a given task."""
|
100 |
+
import numpy as np
|
101 |
+
cosDistance = ec_result.recognitionModel.grammarLogProductionDistanceToTask(task, tasks)
|
102 |
+
argSort = np.argsort(-cosDistance)# Want the greatest similarity.
|
103 |
+
topK = argSort[:k]
|
104 |
+
topKTasks = list(np.array(tasks)[topK])
|
105 |
+
return topKTasks
|
106 |
+
|
107 |
+
|
108 |
+
class RandomkNNTaskBatcher:
|
109 |
+
"""Chooses a random task and finds the (taskBatchSize - 1) nearest neighbors using the recognition model logits."""
|
110 |
+
def __init__(self):
|
111 |
+
pass
|
112 |
+
|
113 |
+
def getTaskBatch(self, ec_result, tasks, taskBatchSize, currIteration):
|
114 |
+
if taskBatchSize is None:
|
115 |
+
taskBatchSize = len(tasks)
|
116 |
+
elif taskBatchSize > len(tasks):
|
117 |
+
eprint("Task batch size is greater than total number of tasks, aborting.")
|
118 |
+
assert False
|
119 |
+
|
120 |
+
if ec_result.recognitionModel is None:
|
121 |
+
eprint("No recognition model, falling back on random %d" % taskBatchSize)
|
122 |
+
return random.sample(tasks, taskBatchSize)
|
123 |
+
else:
|
124 |
+
randomTask = random.choice(tasks)
|
125 |
+
kNN = kNearestNeighbors(ec_result, tasks, taskBatchSize - 1, randomTask)
|
126 |
+
return [randomTask] + kNN
|
127 |
+
|
128 |
+
class RandomLowEntropykNNTaskBatcher:
|
129 |
+
"""Choose a random task from the 10 unsolved with the lowest entropy, and finds the (taskBatchSize - 1) nearest neighbors using the recognition model logits."""
|
130 |
+
def __init__(self):
|
131 |
+
pass
|
132 |
+
|
133 |
+
def getTaskBatch(self, ec_result, tasks, taskBatchSize, currIteration):
|
134 |
+
unsolvedTasks = [t for t in tasks if ec_result.allFrontiers[t].empty]
|
135 |
+
|
136 |
+
if taskBatchSize is None:
|
137 |
+
return unsolvedTasks
|
138 |
+
elif taskBatchSize > len(tasks):
|
139 |
+
eprint("Task batch size is greater than total number of tasks, aborting.")
|
140 |
+
assert False
|
141 |
+
|
142 |
+
if ec_result.recognitionModel is None:
|
143 |
+
eprint("No recognition model, falling back on random %d tasks from the remaining %d" %(taskBatchSize, len(unsolvedTasks)))
|
144 |
+
return random.sample(unsolvedTasks, taskBatchSize)
|
145 |
+
else:
|
146 |
+
lowEntropyUnsolved = entropyRandomBatch(ec_result, unsolvedTasks, taskBatchSize, randomRatio=0)
|
147 |
+
randomTask = random.choice(lowEntropyUnsolved)
|
148 |
+
kNN = kNearestNeighbors(ec_result, tasks, taskBatchSize - 1, randomTask)
|
149 |
+
return [randomTask] + kNN
|
150 |
+
|
151 |
+
|
152 |
+
class UnsolvedEntropyTaskBatcher:
|
153 |
+
"""Returns tasks that have never been solved at any previous iteration.
|
154 |
+
Given a task batch size, returns the unsolved tasks with the lowest entropy."""
|
155 |
+
def __init__(self):
|
156 |
+
pass
|
157 |
+
|
158 |
+
def getTaskBatch(self, ec_result, tasks, taskBatchSize, currIteration):
|
159 |
+
unsolvedTasks = [t for t in tasks if ec_result.allFrontiers[t].empty]
|
160 |
+
|
161 |
+
if taskBatchSize is None:
|
162 |
+
return unsolvedTasks
|
163 |
+
elif taskBatchSize > len(tasks):
|
164 |
+
eprint("Task batch size is greater than total number of tasks, aborting.")
|
165 |
+
assert False
|
166 |
+
|
167 |
+
if ec_result.recognitionModel is None:
|
168 |
+
eprint("No recognition model, falling back on random %d tasks from the remaining %d" %(taskBatchSize, len(unsolvedTasks)))
|
169 |
+
return random.sample(unsolvedTasks, taskBatchSize)
|
170 |
+
else:
|
171 |
+
return entropyRandomBatch(ec_result, unsolvedTasks, taskBatchSize, randomRatio=0)
|
172 |
+
|
173 |
+
class UnsolvedRandomEntropyTaskBatcher:
|
174 |
+
"""Returns tasks that have never been solved at any previous iteration.
|
175 |
+
Given a task batch size, returns a mix of unsolved tasks with percentRandom
|
176 |
+
selected randomly and the remaining selected by lowest entropy."""
|
177 |
+
def __init__(self):
|
178 |
+
pass
|
179 |
+
|
180 |
+
def getTaskBatch(self, ec_result, tasks, taskBatchSize, currIteration):
|
181 |
+
unsolvedTasks = [t for t in tasks if ec_result.allFrontiers[t].empty]
|
182 |
+
|
183 |
+
if taskBatchSize is None:
|
184 |
+
return unsolvedTasks
|
185 |
+
elif taskBatchSize > len(tasks):
|
186 |
+
eprint("Task batch size is greater than total number of tasks, aborting.")
|
187 |
+
assert False
|
188 |
+
|
189 |
+
if ec_result.recognitionModel is None:
|
190 |
+
eprint("No recognition model, falling back on random %d tasks from the remaining %d" %(taskBatchSize, len(unsolvedTasks)))
|
191 |
+
return random.sample(unsolvedTasks, taskBatchSize)
|
192 |
+
else:
|
193 |
+
return entropyRandomBatch(ec_result, unsolvedTasks, taskBatchSize, randomRatio=.5)
|
194 |
+
|
195 |
+
|
196 |
+
|
197 |
+
|
198 |
+
|
199 |
+
|
200 |
+
|
dreamcoder/type.py
ADDED
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class UnificationFailure(Exception):
|
2 |
+
pass
|
3 |
+
|
4 |
+
|
5 |
+
class Occurs(UnificationFailure):
|
6 |
+
pass
|
7 |
+
|
8 |
+
|
9 |
+
class Type(object):
|
10 |
+
def __str__(self): return self.show(True)
|
11 |
+
|
12 |
+
def __repr__(self): return str(self)
|
13 |
+
|
14 |
+
@staticmethod
|
15 |
+
def fromjson(j):
|
16 |
+
if "index" in j: return TypeVariable(j["index"])
|
17 |
+
if "constructor" in j: return TypeConstructor(j["constructor"],
|
18 |
+
[ Type.fromjson(a) for a in j["arguments"] ])
|
19 |
+
assert False
|
20 |
+
|
21 |
+
|
22 |
+
class TypeConstructor(Type):
|
23 |
+
def __init__(self, name, arguments):
|
24 |
+
self.name = name
|
25 |
+
self.arguments = arguments
|
26 |
+
self.isPolymorphic = any(a.isPolymorphic for a in arguments)
|
27 |
+
|
28 |
+
def makeDummyMonomorphic(self, mapping=None):
|
29 |
+
mapping = mapping if mapping is not None else {}
|
30 |
+
return TypeConstructor(self.name,
|
31 |
+
[ a.makeDummyMonomorphic(mapping) for a in self.arguments ])
|
32 |
+
|
33 |
+
def __eq__(self, other):
|
34 |
+
return isinstance(other, TypeConstructor) and \
|
35 |
+
self.name == other.name and \
|
36 |
+
all(x == y for x, y in zip(self.arguments, other.arguments))
|
37 |
+
|
38 |
+
def __hash__(self): return hash((self.name,) + tuple(self.arguments))
|
39 |
+
|
40 |
+
def __ne__(self, other):
|
41 |
+
return not (self == other)
|
42 |
+
|
43 |
+
def show(self, isReturn):
|
44 |
+
if self.name == ARROW:
|
45 |
+
if isReturn:
|
46 |
+
return "%s %s %s" % (self.arguments[0].show(
|
47 |
+
False), ARROW, self.arguments[1].show(True))
|
48 |
+
else:
|
49 |
+
return "(%s %s %s)" % (self.arguments[0].show(
|
50 |
+
False), ARROW, self.arguments[1].show(True))
|
51 |
+
elif self.arguments == []:
|
52 |
+
return self.name
|
53 |
+
else:
|
54 |
+
return "%s(%s)" % (self.name, ", ".join(x.show(True)
|
55 |
+
for x in self.arguments))
|
56 |
+
|
57 |
+
def json(self):
|
58 |
+
return {"constructor": self.name,
|
59 |
+
"arguments": [a.json() for a in self.arguments]}
|
60 |
+
|
61 |
+
|
62 |
+
def isArrow(self): return self.name == ARROW
|
63 |
+
|
64 |
+
def functionArguments(self):
|
65 |
+
if self.name == ARROW:
|
66 |
+
xs = self.arguments[1].functionArguments()
|
67 |
+
return [self.arguments[0]] + xs
|
68 |
+
return []
|
69 |
+
|
70 |
+
def returns(self):
|
71 |
+
if self.name == ARROW:
|
72 |
+
return self.arguments[1].returns()
|
73 |
+
else:
|
74 |
+
return self
|
75 |
+
|
76 |
+
def apply(self, context):
|
77 |
+
if not self.isPolymorphic:
|
78 |
+
return self
|
79 |
+
return TypeConstructor(self.name,
|
80 |
+
[x.apply(context) for x in self.arguments])
|
81 |
+
|
82 |
+
def applyMutable(self, context):
|
83 |
+
if not self.isPolymorphic:
|
84 |
+
return self
|
85 |
+
return TypeConstructor(self.name,
|
86 |
+
[x.applyMutable(context) for x in self.arguments])
|
87 |
+
|
88 |
+
def occurs(self, v):
|
89 |
+
if not self.isPolymorphic:
|
90 |
+
return False
|
91 |
+
return any(x.occurs(v) for x in self.arguments)
|
92 |
+
|
93 |
+
def negateVariables(self):
|
94 |
+
return TypeConstructor(self.name,
|
95 |
+
[a.negateVariables() for a in self.arguments])
|
96 |
+
|
97 |
+
def instantiate(self, context, bindings=None):
|
98 |
+
if not self.isPolymorphic:
|
99 |
+
return context, self
|
100 |
+
if bindings is None:
|
101 |
+
bindings = {}
|
102 |
+
newArguments = []
|
103 |
+
for x in self.arguments:
|
104 |
+
(context, x) = x.instantiate(context, bindings)
|
105 |
+
newArguments.append(x)
|
106 |
+
return (context, TypeConstructor(self.name, newArguments))
|
107 |
+
|
108 |
+
def instantiateMutable(self, context, bindings=None):
|
109 |
+
if not self.isPolymorphic:
|
110 |
+
return self
|
111 |
+
if bindings is None:
|
112 |
+
bindings = {}
|
113 |
+
newArguments = []
|
114 |
+
return TypeConstructor(self.name, [x.instantiateMutable(context, bindings)
|
115 |
+
for x in self.arguments ])
|
116 |
+
|
117 |
+
|
118 |
+
def canonical(self, bindings=None):
|
119 |
+
if not self.isPolymorphic:
|
120 |
+
return self
|
121 |
+
if bindings is None:
|
122 |
+
bindings = {}
|
123 |
+
return TypeConstructor(self.name,
|
124 |
+
[x.canonical(bindings) for x in self.arguments])
|
125 |
+
|
126 |
+
|
127 |
+
class TypeVariable(Type):
|
128 |
+
def __init__(self, j):
|
129 |
+
assert isinstance(j, int)
|
130 |
+
self.v = j
|
131 |
+
self.isPolymorphic = True
|
132 |
+
|
133 |
+
def makeDummyMonomorphic(self, mapping=None):
|
134 |
+
mapping = mapping if mapping is not None else {}
|
135 |
+
if self.v not in mapping:
|
136 |
+
mapping[self.v] = TypeConstructor(f"dummy_type_{len(mapping)}", [])
|
137 |
+
return mapping[self.v]
|
138 |
+
|
139 |
+
|
140 |
+
def __eq__(self, other):
|
141 |
+
return isinstance(other, TypeVariable) and self.v == other.v
|
142 |
+
|
143 |
+
def __ne__(self, other): return not (self.v == other.v)
|
144 |
+
|
145 |
+
def __hash__(self): return self.v
|
146 |
+
|
147 |
+
def show(self, _): return "t%d" % self.v
|
148 |
+
|
149 |
+
def json(self):
|
150 |
+
return {"index": self.v}
|
151 |
+
|
152 |
+
def returns(self): return self
|
153 |
+
|
154 |
+
def isArrow(self): return False
|
155 |
+
|
156 |
+
def functionArguments(self): return []
|
157 |
+
|
158 |
+
def apply(self, context):
|
159 |
+
for v, t in context.substitution:
|
160 |
+
if v == self.v:
|
161 |
+
return t.apply(context)
|
162 |
+
return self
|
163 |
+
|
164 |
+
def applyMutable(self, context):
|
165 |
+
s = context.substitution[self.v]
|
166 |
+
if s is None: return self
|
167 |
+
new = s.applyMutable(context)
|
168 |
+
context.substitution[self.v] = new
|
169 |
+
return new
|
170 |
+
|
171 |
+
def occurs(self, v): return v == self.v
|
172 |
+
|
173 |
+
def instantiate(self, context, bindings=None):
|
174 |
+
if bindings is None:
|
175 |
+
bindings = {}
|
176 |
+
if self.v in bindings:
|
177 |
+
return (context, bindings[self.v])
|
178 |
+
new = TypeVariable(context.nextVariable)
|
179 |
+
bindings[self.v] = new
|
180 |
+
context = Context(context.nextVariable + 1, context.substitution)
|
181 |
+
return (context, new)
|
182 |
+
|
183 |
+
def instantiateMutable(self, context, bindings=None):
|
184 |
+
if bindings is None: bindings = {}
|
185 |
+
if self.v in bindings: return bindings[self.v]
|
186 |
+
new = context.makeVariable()
|
187 |
+
bindings[self.v] = new
|
188 |
+
return new
|
189 |
+
|
190 |
+
def canonical(self, bindings=None):
|
191 |
+
if bindings is None:
|
192 |
+
bindings = {}
|
193 |
+
if self.v in bindings:
|
194 |
+
return bindings[self.v]
|
195 |
+
new = TypeVariable(len(bindings))
|
196 |
+
bindings[self.v] = new
|
197 |
+
return new
|
198 |
+
|
199 |
+
def negateVariables(self):
|
200 |
+
return TypeVariable(-1 - self.v)
|
201 |
+
|
202 |
+
|
203 |
+
class Context(object):
|
204 |
+
def __init__(self, nextVariable=0, substitution=[]):
|
205 |
+
self.nextVariable = nextVariable
|
206 |
+
self.substitution = substitution
|
207 |
+
|
208 |
+
def extend(self, j, t):
|
209 |
+
return Context(self.nextVariable, [(j, t)] + self.substitution)
|
210 |
+
|
211 |
+
def makeVariable(self):
|
212 |
+
return (Context(self.nextVariable + 1, self.substitution),
|
213 |
+
TypeVariable(self.nextVariable))
|
214 |
+
|
215 |
+
def unify(self, t1, t2):
|
216 |
+
t1 = t1.apply(self)
|
217 |
+
t2 = t2.apply(self)
|
218 |
+
if t1 == t2:
|
219 |
+
return self
|
220 |
+
# t1&t2 are not equal
|
221 |
+
if not t1.isPolymorphic and not t2.isPolymorphic:
|
222 |
+
raise UnificationFailure(t1, t2)
|
223 |
+
|
224 |
+
if isinstance(t1, TypeVariable):
|
225 |
+
if t2.occurs(t1.v):
|
226 |
+
raise Occurs()
|
227 |
+
return self.extend(t1.v, t2)
|
228 |
+
if isinstance(t2, TypeVariable):
|
229 |
+
if t1.occurs(t2.v):
|
230 |
+
raise Occurs()
|
231 |
+
return self.extend(t2.v, t1)
|
232 |
+
if t1.name != t2.name:
|
233 |
+
raise UnificationFailure(t1, t2)
|
234 |
+
k = self
|
235 |
+
for x, y in zip(t2.arguments, t1.arguments):
|
236 |
+
k = k.unify(x, y)
|
237 |
+
return k
|
238 |
+
|
239 |
+
def __str__(self):
|
240 |
+
return "Context(next = %d, {%s})" % (self.nextVariable, ", ".join(
|
241 |
+
"t%d ||> %s" % (k, v.apply(self)) for k, v in self.substitution))
|
242 |
+
|
243 |
+
def __repr__(self): return str(self)
|
244 |
+
|
245 |
+
class MutableContext(object):
|
246 |
+
def __init__(self):
|
247 |
+
self.substitution = []
|
248 |
+
|
249 |
+
def extend(self,i,t):
|
250 |
+
assert self.substitution[i] is None
|
251 |
+
self.substitution[i] = t
|
252 |
+
|
253 |
+
def makeVariable(self):
|
254 |
+
self.substitution.append(None)
|
255 |
+
return TypeVariable(len(self.substitution) - 1)
|
256 |
+
|
257 |
+
def unify(self, t1, t2):
|
258 |
+
t1 = t1.applyMutable(self)
|
259 |
+
t2 = t2.applyMutable(self)
|
260 |
+
|
261 |
+
if t1 == t2: return
|
262 |
+
|
263 |
+
# t1&t2 are not equal
|
264 |
+
if not t1.isPolymorphic and not t2.isPolymorphic:
|
265 |
+
raise UnificationFailure(t1, t2)
|
266 |
+
|
267 |
+
if isinstance(t1, TypeVariable):
|
268 |
+
if t2.occurs(t1.v):
|
269 |
+
raise Occurs()
|
270 |
+
self.extend(t1.v, t2)
|
271 |
+
return
|
272 |
+
if isinstance(t2, TypeVariable):
|
273 |
+
if t1.occurs(t2.v):
|
274 |
+
raise Occurs()
|
275 |
+
self.extend(t2.v, t1)
|
276 |
+
return
|
277 |
+
if t1.name != t2.name:
|
278 |
+
raise UnificationFailure(t1, t2)
|
279 |
+
|
280 |
+
for x, y in zip(t2.arguments, t1.arguments):
|
281 |
+
self.unify(x, y)
|
282 |
+
|
283 |
+
|
284 |
+
Context.EMPTY = Context(0, [])
|
285 |
+
|
286 |
+
|
287 |
+
def canonicalTypes(ts):
|
288 |
+
bindings = {}
|
289 |
+
return [t.canonical(bindings) for t in ts]
|
290 |
+
|
291 |
+
|
292 |
+
def instantiateTypes(context, ts):
|
293 |
+
bindings = {}
|
294 |
+
newTypes = []
|
295 |
+
for t in ts:
|
296 |
+
context, t = t.instantiate(context, bindings)
|
297 |
+
newTypes.append(t)
|
298 |
+
return context, newTypes
|
299 |
+
|
300 |
+
|
301 |
+
def baseType(n): return TypeConstructor(n, [])
|
302 |
+
|
303 |
+
|
304 |
+
tint = baseType("int")
|
305 |
+
treal = baseType("real")
|
306 |
+
tbool = baseType("bool")
|
307 |
+
tboolean = tbool # alias
|
308 |
+
tcharacter = baseType("char")
|
309 |
+
|
310 |
+
|
311 |
+
def tlist(t): return TypeConstructor("list", [t])
|
312 |
+
|
313 |
+
|
314 |
+
def tpair(a, b): return TypeConstructor("pair", [a, b])
|
315 |
+
|
316 |
+
|
317 |
+
def tmaybe(t): return TypeConstructor("maybe", [t])
|
318 |
+
|
319 |
+
|
320 |
+
tstr = tlist(tcharacter)
|
321 |
+
t0 = TypeVariable(0)
|
322 |
+
t1 = TypeVariable(1)
|
323 |
+
t2 = TypeVariable(2)
|
324 |
+
|
325 |
+
# regex types
|
326 |
+
tpregex = baseType("pregex")
|
327 |
+
|
328 |
+
ARROW = "->"
|
329 |
+
|
330 |
+
|
331 |
+
def arrow(*arguments):
|
332 |
+
if len(arguments) == 1:
|
333 |
+
return arguments[0]
|
334 |
+
return TypeConstructor(ARROW, [arguments[0], arrow(*arguments[1:])])
|
335 |
+
|
336 |
+
|
337 |
+
def inferArg(tp, tcaller):
|
338 |
+
ctx, tp = tp.instantiate(Context.EMPTY)
|
339 |
+
ctx, tcaller = tcaller.instantiate(ctx)
|
340 |
+
ctx, targ = ctx.makeVariable()
|
341 |
+
ctx = ctx.unify(tcaller, arrow(targ, tp))
|
342 |
+
return targ.apply(ctx)
|
343 |
+
|
344 |
+
|
345 |
+
def guess_type(xs):
|
346 |
+
"""
|
347 |
+
Return a TypeConstructor corresponding to x's python type.
|
348 |
+
Raises an exception if the type cannot be guessed.
|
349 |
+
"""
|
350 |
+
if all(isinstance(x, bool) for x in xs):
|
351 |
+
return tbool
|
352 |
+
elif all(isinstance(x, int) for x in xs):
|
353 |
+
return tint
|
354 |
+
elif all(isinstance(x, str) for x in xs):
|
355 |
+
return tstr
|
356 |
+
elif all(isinstance(x, list) for x in xs):
|
357 |
+
return tlist(guess_type([y for ys in xs for y in ys]))
|
358 |
+
else:
|
359 |
+
raise ValueError("cannot guess type from {}".format(xs))
|
360 |
+
|
361 |
+
|
362 |
+
def guess_arrow_type(examples):
|
363 |
+
a = len(examples[0][0])
|
364 |
+
input_types = []
|
365 |
+
for n in range(a):
|
366 |
+
input_types.append(guess_type([xs[n] for xs, _ in examples]))
|
367 |
+
output_type = guess_type([y for _, y in examples])
|
368 |
+
return arrow(*(input_types + [output_type]))
|
369 |
+
|
370 |
+
def canUnify(t1, t2):
|
371 |
+
k = MutableContext()
|
372 |
+
t1 = t1.instantiateMutable(k)
|
373 |
+
t2 = t2.instantiateMutable(k)
|
374 |
+
try:
|
375 |
+
k.unify(t1, t2)
|
376 |
+
return True
|
377 |
+
except UnificationFailure: return False
|
378 |
+
|