Fraser-Greenlee committed
Commit e1c1753
1 Parent(s): 4c34db7

add dreamcoder codebase

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full changeset.
Files changed (50)
  1. dreamcoder/__init__.py +107 -0
  2. dreamcoder/compression.py +282 -0
  3. dreamcoder/deprecated/__init__.py +0 -0
  4. dreamcoder/deprecated/network.py +479 -0
  5. dreamcoder/differentiation.py +393 -0
  6. dreamcoder/domains/__init__.py +0 -0
  7. dreamcoder/domains/arithmetic/__init__.py +0 -0
  8. dreamcoder/domains/arithmetic/arithmeticPrimitives.py +58 -0
  9. dreamcoder/domains/list/__init__.py +0 -0
  10. dreamcoder/domains/list/listPrimitives.py +546 -0
  11. dreamcoder/domains/list/main.py +410 -0
  12. dreamcoder/domains/list/makeListTasks.py +587 -0
  13. dreamcoder/domains/logo/__init__.py +0 -0
  14. dreamcoder/domains/logo/logoPrimitives.py +41 -0
  15. dreamcoder/domains/logo/main.py +450 -0
  16. dreamcoder/domains/logo/makeLogoTasks.py +777 -0
  17. dreamcoder/domains/misc/RobustFillPrimitives.py +308 -0
  18. dreamcoder/domains/misc/__init__.py +0 -0
  19. dreamcoder/domains/misc/algolispPrimitives.py +508 -0
  20. dreamcoder/domains/misc/deepcoderPrimitives.py +352 -0
  21. dreamcoder/domains/misc/napsPrimitives.py +198 -0
  22. dreamcoder/domains/regex/__init__.py +0 -0
  23. dreamcoder/domains/regex/groundtruthRegexes.py +172 -0
  24. dreamcoder/domains/regex/main.py +384 -0
  25. dreamcoder/domains/regex/makeRegexTasks.py +347 -0
  26. dreamcoder/domains/regex/regexPrimitives.py +367 -0
  27. dreamcoder/domains/text/__init__.py +0 -0
  28. dreamcoder/domains/text/main.py +270 -0
  29. dreamcoder/domains/text/makeTextTasks.py +424 -0
  30. dreamcoder/domains/text/textPrimitives.py +87 -0
  31. dreamcoder/domains/tower/__init__.py +0 -0
  32. dreamcoder/domains/tower/main.py +359 -0
  33. dreamcoder/domains/tower/makeTowerTasks.py +556 -0
  34. dreamcoder/domains/tower/towerPrimitives.py +152 -0
  35. dreamcoder/domains/tower/tower_common.py +173 -0
  36. dreamcoder/dreamcoder.py +1074 -0
  37. dreamcoder/dreaming.py +90 -0
  38. dreamcoder/ec.py +3 -0
  39. dreamcoder/enumeration.py +469 -0
  40. dreamcoder/fragmentGrammar.py +430 -0
  41. dreamcoder/fragmentUtilities.py +405 -0
  42. dreamcoder/frontier.py +247 -0
  43. dreamcoder/grammar.py +1308 -0
  44. dreamcoder/likelihoodModel.py +407 -0
  45. dreamcoder/primitiveGraph.py +182 -0
  46. dreamcoder/program.py +1214 -0
  47. dreamcoder/recognition.py +1528 -0
  48. dreamcoder/task.py +244 -0
  49. dreamcoder/taskBatcher.py +200 -0
  50. dreamcoder/type.py +378 -0
dreamcoder/__init__.py ADDED
@@ -0,0 +1,107 @@

"""
EC codebase Python library (AKA the "frontend")

Module mapping details:

TODO: remove module mapping code when backwards-compatibility is no longer required.

The below module mapping is required for backwards-compatibility with old pickle files
generated from before the EC codebase refactor. New files added to the codebase do not
need to be added to the mapping, but if existing modules are moved or renamed, then the
mapping needs to be updated to reflect the move or rename.

The mapping uses the following pattern:

    sys.modules[<old module path>] = <new module reference>

This is because the previous structure of the codebase was completely flat, and after refactoring
to a hierarchical file structure, loading previous pickle files no longer works properly. It is important
to retain the ability to read old pickle files generated from official experiments. As a workaround,
the old module paths are included below. A preferable alternative would be to export program state
into JSON files instead of pickle files to avoid issues where the underlying classes change, so that
could be a future improvement to this project. Until then, we use the module mapping workaround.

For more info, see this StackOverflow answer: https://stackoverflow.com/a/2121918/2573242
"""
import sys

from dreamcoder import differentiation
from dreamcoder import dreamcoder
from dreamcoder import enumeration
from dreamcoder import fragmentGrammar
from dreamcoder import fragmentUtilities
from dreamcoder import frontier
from dreamcoder import grammar
from dreamcoder import likelihoodModel
from dreamcoder import program
from dreamcoder import primitiveGraph
try:
    from dreamcoder import recognition
except Exception:
    print("Failure loading recognition - only acceptable if using pypy", file=sys.stderr)
from dreamcoder import task
from dreamcoder import taskBatcher
from dreamcoder import type
from dreamcoder import utilities
from dreamcoder import vs
from dreamcoder.domains.misc import algolispPrimitives, deepcoderPrimitives
from dreamcoder.domains.misc import RobustFillPrimitives
from dreamcoder.domains.misc import napsPrimitives
from dreamcoder.domains.tower import makeTowerTasks
from dreamcoder.domains.tower import towerPrimitives
from dreamcoder.domains.tower import tower_common
from dreamcoder.domains.tower import main as tower_main
from dreamcoder.domains.regex import groundtruthRegexes
from dreamcoder.domains.regex import regexPrimitives
from dreamcoder.domains.regex import makeRegexTasks
#from dreamcoder.domains.regex import main as regex_main
from dreamcoder.domains.logo import logoPrimitives
from dreamcoder.domains.logo import makeLogoTasks
from dreamcoder.domains.logo import main as logo_main
from dreamcoder.domains.list import listPrimitives
from dreamcoder.domains.list import makeListTasks
from dreamcoder.domains.list import main as list_main
from dreamcoder.domains.arithmetic import arithmeticPrimitives
from dreamcoder.domains.text import textPrimitives
from dreamcoder.domains.text import makeTextTasks
from dreamcoder.domains.text import main as text_main

sys.modules['differentiation'] = differentiation
sys.modules['ec'] = dreamcoder
sys.modules['enumeration'] = enumeration
sys.modules['fragmentGrammar'] = fragmentGrammar
sys.modules['fragmentUtilities'] = fragmentUtilities
sys.modules['frontier'] = frontier
sys.modules['grammar'] = grammar
sys.modules['likelihoodModel'] = likelihoodModel
sys.modules['program'] = program
try:
    sys.modules['recognition'] = recognition
except NameError:
    # recognition failed to import above (e.g. under pypy)
    pass
sys.modules['task'] = task
sys.modules['taskBatcher'] = taskBatcher
sys.modules['type'] = type
sys.modules['utilities'] = utilities
sys.modules['vs'] = vs
sys.modules['algolispPrimitives'] = algolispPrimitives
sys.modules['RobustFillPrimitives'] = RobustFillPrimitives
sys.modules['napsPrimitives'] = napsPrimitives
sys.modules['makeTowerTasks'] = makeTowerTasks
sys.modules['towerPrimitives'] = towerPrimitives
sys.modules['tower_common'] = tower_common
#sys.modules['tower'] = tower_main
sys.modules['groundtruthRegexes'] = groundtruthRegexes
sys.modules['regexPrimitives'] = regexPrimitives
sys.modules['makeRegexTasks'] = makeRegexTasks
#sys.modules['regexes'] = regex_main
sys.modules['deepcoderPrimitives'] = deepcoderPrimitives
sys.modules['logoPrimitives'] = logoPrimitives
sys.modules['makeLogoTasks'] = makeLogoTasks
#sys.modules['logo'] = logo_main
sys.modules['listPrimitives'] = listPrimitives
sys.modules['makeListTasks'] = makeListTasks
#sys.modules['list'] = list_main
sys.modules['arithmeticPrimitives'] = arithmeticPrimitives
sys.modules['textPrimitives'] = textPrimitives
sys.modules['makeTextTasks'] = makeTextTasks
#sys.modules['text'] = text_main
sys.modules['primitiveGraph'] = primitiveGraph
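
The block of sys.modules assignments above is the standard alias trick for unpickling objects whose classes moved during a refactor: pickle stores the class's module path and re-imports that module on load, and an entry in sys.modules satisfies the import. A minimal self-contained sketch of the mechanism (the `point`/`newpkg` names here are hypothetical, not part of this commit):

    import pickle
    import sys
    import types

    # Pretend the class used to live in a flat top-level module `point`
    # and now lives in a package; old pickles still record "point".
    new_home = types.ModuleType("newpkg.point")

    class Point:
        def __init__(self, x, y):
            self.x, self.y = x, y

    Point.__module__ = "point"    # what an old pickle would have recorded
    new_home.Point = Point

    sys.modules["point"] = new_home   # the backwards-compatibility alias

    blob = pickle.dumps(Point(1, 2))  # stores ("point", "Point")
    restored = pickle.loads(blob)     # resolved through sys.modules["point"]
    print(restored.x, restored.y)     # 1 2
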
dreamcoder/compression.py ADDED
@@ -0,0 +1,282 @@

import datetime
import json
import os
import pickle
import subprocess
import sys

from dreamcoder.fragmentGrammar import FragmentGrammar
from dreamcoder.frontier import Frontier, FrontierEntry
from dreamcoder.grammar import Grammar
from dreamcoder.task import Task
from dreamcoder.program import Program, Invented
from dreamcoder.utilities import eprint, timing, callCompiled, get_root_dir
from dreamcoder.vs import induceGrammar_Beta


def induceGrammar(*args, **kwargs):
    if sum(not f.empty for f in args[1]) == 0:
        eprint("No nonempty frontiers, exiting grammar induction early.")
        return args[0], args[1]
    backend = kwargs.pop("backend", "pypy")
    if 'pypy' in backend:
        # pypy might not like some of the imports needed for the primitives
        # but the primitive values are irrelevant for compression
        # therefore strip them out and then replace them once we are done
        # ditto for task data
        g0, frontiers = args[0].strip_primitive_values(), \
            [front.strip_primitive_values() for front in args[1]]
        original_tasks = {f.task.name: f.task for f in frontiers}
        frontiers = [Frontier(f.entries, Task(f.task.name, f.task.request, []))
                     for f in frontiers]
        args = [g0, frontiers]

    with timing("Induced a grammar"):
        if backend == "pypy":
            g, newFrontiers = callCompiled(pypyInduce, *args, **kwargs)
        elif backend == "rust":
            g, newFrontiers = rustInduce(*args, **kwargs)
        elif backend == "vs":
            g, newFrontiers = rustInduce(*args, vs=True, **kwargs)
        elif backend == "pypy_vs":
            kwargs.pop('iteration')
            kwargs.pop('topk_use_only_likelihood')
            fn = '/tmp/vs.pickle'
            with open(fn, 'wb') as handle:
                pickle.dump((args, kwargs), handle)
            eprint("For debugging purposes, the version space compression invocation has been saved to", fn)
            g, newFrontiers = callCompiled(induceGrammar_Beta, *args, **kwargs)
        elif backend == "ocaml":
            kwargs.pop('iteration')
            kwargs.pop('topk_use_only_likelihood')
            kwargs['topI'] = 300
            kwargs['bs'] = 1000000
            g, newFrontiers = ocamlInduce(*args, **kwargs)
        elif backend == "memorize":
            g, newFrontiers = memorizeInduce(*args, **kwargs)
        else:
            assert False, "unknown compressor"

    if 'pypy' in backend:
        g, newFrontiers = g.unstrip_primitive_values(), \
            [front.unstrip_primitive_values() for front in newFrontiers]
        newFrontiers = [Frontier(f.entries, original_tasks[f.task.name])
                        for f in newFrontiers]

    return g, newFrontiers


def memorizeInduce(g, frontiers, **kwargs):
    existingInventions = {p.uncurry()
                          for p in g.primitives}
    programs = {f.bestPosterior.program for f in frontiers if not f.empty}
    newInventions = programs - existingInventions
    newGrammar = Grammar.uniform([p for p in g.primitives] +
                                 [Invented(ni) for ni in newInventions])

    # rewrite in terms of new primitives
    def substitute(p):
        nonlocal newInventions
        if p in newInventions:
            return Invented(p).uncurry()
        return p
    newFrontiers = [Frontier([FrontierEntry(program=np,
                                            logPrior=newGrammar.logLikelihood(f.task.request, np),
                                            logLikelihood=e.logLikelihood)
                              for e in f
                              for np in [substitute(e.program)]],
                             task=f.task)
                    for f in frontiers]
    return newGrammar, newFrontiers


def pypyInduce(*args, **kwargs):
    kwargs.pop('iteration')
    return FragmentGrammar.induceFromFrontiers(*args, **kwargs)


def ocamlInduce(g, frontiers, _=None,
                topK=1, pseudoCounts=1.0, aic=1.0,
                structurePenalty=0.001, a=0, CPUs=1,
                bs=1000000, topI=300):
    # This is a dirty hack!
    # Memory consumption increases with the number of CPUs
    # And early on we have a lot of stuff to compress
    # If this is the first iteration, only use a fraction of the available CPUs
    if all(not p.isInvented for p in g.primitives):
        if a > 3:
            CPUs = max(1, int(CPUs / 6))
        else:
            CPUs = max(1, int(CPUs / 3))
    else:
        CPUs = max(1, int(CPUs / 2))
    CPUs = 2

    # X X X FIXME X X X
    # for unknown reasons doing compression all in one go works correctly and doing it with Python and the outer loop causes problems
    iterations = 99  # maximum number of components to add at once

    while True:
        g0 = g

        originalFrontiers = frontiers
        t2f = {f.task: f for f in frontiers}
        frontiers = [f for f in frontiers if not f.empty]
        message = {"arity": a,
                   "topK": topK,
                   "pseudoCounts": float(pseudoCounts),
                   "aic": aic,
                   "bs": bs,
                   "topI": topI,
                   "structurePenalty": float(structurePenalty),
                   "CPUs": CPUs,
                   "DSL": g.json(),
                   "iterations": iterations,
                   "frontiers": [f.json()
                                 for f in frontiers]}

        message = json.dumps(message)
        if True:
            timestamp = datetime.datetime.now().isoformat()
            os.system("mkdir -p compressionMessages")
            fn = "compressionMessages/%s" % timestamp
            with open(fn, "w") as f:
                f.write(message)
            eprint("Compression message saved to:", fn)

        try:
            # Get relative path
            compressor_file = os.path.join(get_root_dir(), 'compression')
            process = subprocess.Popen(compressor_file,
                                       stdin=subprocess.PIPE,
                                       stdout=subprocess.PIPE)
            response, error = process.communicate(bytes(message, encoding="utf-8"))
            response = json.loads(response.decode("utf-8"))
        except OSError as exc:
            raise exc

        g = response["DSL"]
        g = Grammar(g["logVariable"],
                    [(l, p.infer(), p)
                     for production in g["productions"]
                     for l in [production["logProbability"]]
                     for p in [Program.parse(production["expression"])]],
                    continuationType=g0.continuationType)

        frontiers = {original.task:
                     Frontier([FrontierEntry(p,
                                             logLikelihood=e["logLikelihood"],
                                             logPrior=g.logLikelihood(original.task.request, p))
                               for e in new["programs"]
                               for p in [Program.parse(e["program"])]],
                              task=original.task)
                     for original, new in zip(frontiers, response["frontiers"])}
        frontiers = [frontiers.get(f.task, t2f[f.task])
                     for f in originalFrontiers]
        if iterations == 1 and len(g) > len(g0):
            eprint("Grammar changed - running another round of consolidation.")
            continue
        else:
            eprint("Finished consolidation.")
            return g, frontiers


def rustInduce(g0, frontiers, _=None,
               topK=1, pseudoCounts=1.0, aic=1.0,
               structurePenalty=0.001, a=0, CPUs=1, iteration=-1,
               topk_use_only_likelihood=False,
               vs=False):
    def finite_logp(l):
        return l if l != float("-inf") else -1000

    message = {
        "strategy": {"version-spaces": {"top_i": 50}}
        if vs else
        {"fragment-grammars": {}},
        "params": {
            "structure_penalty": structurePenalty,
            "pseudocounts": int(pseudoCounts + 0.5),
            "topk": topK,
            "topk_use_only_likelihood": topk_use_only_likelihood,
            "aic": aic if aic != float("inf") else None,
            "arity": a,
        },
        "primitives": [{"name": p.name, "tp": str(t), "logp": finite_logp(l)}
                       for l, t, p in g0.productions if p.isPrimitive],
        "inventions": [{"expression": str(p.body),
                        "logp": finite_logp(l)}  # -inf=-100
                       for l, t, p in g0.productions if p.isInvented],
        "variable_logprob": finite_logp(g0.logVariable),
        "frontiers": [{
            "task_tp": str(f.task.request),
            "solutions": [{
                "expression": str(e.program),
                "logprior": finite_logp(e.logPrior),
                "loglikelihood": e.logLikelihood,
            } for e in f],
        } for f in frontiers],
    }

    eprint("running rust compressor")

    messageJson = json.dumps(message)

    with open("jsonDebug", "w") as f:
        f.write(messageJson)

    # check which version of python we are using
    # if >=3.6 do:
    if sys.version_info[1] >= 6:
        p = subprocess.Popen(
            ['./rust_compressor/rust_compressor'],
            encoding='utf-8',
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE)
    elif sys.version_info[1] == 5:
        p = subprocess.Popen(
            ['./rust_compressor/rust_compressor'],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE)

        # convert messageJson string to bytes
        messageJson = bytearray(messageJson, encoding='utf-8')
    else:
        eprint("must be python 3.5 or 3.6")
        assert False

    p.stdin.write(messageJson)
    p.stdin.flush()
    p.stdin.close()

    if p.returncode is not None:
        raise ValueError("rust compressor failed")

    if sys.version_info[1] >= 6:
        resp = json.load(p.stdout)
    elif sys.version_info[1] == 5:
        import codecs
        resp = json.load(codecs.getreader('utf-8')(p.stdout))

    productions = [(x["logp"], p) for p, x in
                   zip((p for (_, _, p) in g0.productions if p.isPrimitive), resp["primitives"])] + \
                  [(i["logp"], Invented(Program.parse(i["expression"])))
                   for i in resp["inventions"]]
    productions = [(l if l is not None else float("-inf"), p)
                   for l, p in productions]
    g = Grammar.fromProductions(productions, resp["variable_logprob"], continuationType=g0.continuationType)
    newFrontiers = [Frontier([FrontierEntry(Program.parse(s["expression"]),
                                            logPrior=s["logprior"],
                                            logLikelihood=s["loglikelihood"])
                              for s in r["solutions"]],
                             f.task)
                    for f, r in zip(frontiers, resp["frontiers"])]
    return g, newFrontiers
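
Both external backends above (ocamlInduce and rustInduce) speak the same simple protocol: dump a JSON request to the child process's stdin and read a JSON response from its stdout. A runnable sketch of that pattern, with a `python -c` one-liner standing in for the real `./compression` or `./rust_compressor` binary:

    import json
    import subprocess
    import sys

    # Stand-in "compressor": reads one JSON object from stdin,
    # writes one JSON object to stdout.
    fake_compressor = [sys.executable, "-c",
                       "import sys, json; "
                       "msg = json.load(sys.stdin); "
                       "json.dump({'echoed_arity': msg['arity']}, sys.stdout)"]

    request = {"arity": 2, "topK": 1, "frontiers": []}
    proc = subprocess.Popen(fake_compressor,
                            stdin=subprocess.PIPE, stdout=subprocess.PIPE)
    out, _ = proc.communicate(json.dumps(request).encode("utf-8"))
    response = json.loads(out.decode("utf-8"))
    print(response)  # {'echoed_arity': 2}
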
dreamcoder/deprecated/__init__.py ADDED
File without changes
dreamcoder/deprecated/network.py ADDED
@@ -0,0 +1,479 @@

"""
Deprecated network.py module. This file only exists to support backwards-compatibility
with old pickle files. See dreamcoder/__init__.py for more information.
"""

from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.parameter import Parameter


# UPGRADING TO INPUT -> OUTPUT -> TARGET
# Todo:
# [X] Output attending to input
# [X] Target attending to output
# [ ] check passing hidden state between encoders/decoder (+ pass c?)
# [ ] add v_output


def choose(matrix, idxs):
    if isinstance(idxs, Variable):
        idxs = idxs.data
    assert(matrix.ndimension() == 2)
    unrolled_idxs = idxs + \
        torch.arange(0, matrix.size(0)).type_as(idxs) * matrix.size(1)
    return matrix.view(matrix.nelement())[unrolled_idxs]


class Network(nn.Module):
    """
    Todo:
    - Beam search
    - check if this is right? attend during P->FC rather than during softmax->P?
    - allow length 0 inputs/targets
    - give n_examples as input to FC
    - Initialise new weights randomly, rather than as zeroes
    """

    def __init__(
            self,
            input_vocabulary,
            target_vocabulary,
            hidden_size=512,
            embedding_size=128,
            cell_type="LSTM"):
        """
        :param list input_vocabulary: list of possible inputs
        :param list target_vocabulary: list of possible targets
        """
        super(Network, self).__init__()
        self.h_input_encoder_size = hidden_size
        self.h_output_encoder_size = hidden_size
        self.h_decoder_size = hidden_size
        self.embedding_size = embedding_size
        self.input_vocabulary = input_vocabulary
        self.target_vocabulary = target_vocabulary
        # Number of tokens in input vocabulary
        self.v_input = len(input_vocabulary)
        # Number of tokens in target vocabulary
        self.v_target = len(target_vocabulary)

        self.cell_type = cell_type
        if cell_type == 'GRU':
            self.input_encoder_cell = nn.GRUCell(
                input_size=self.v_input + 1,
                hidden_size=self.h_input_encoder_size,
                bias=True)
            self.input_encoder_init = Parameter(
                torch.rand(1, self.h_input_encoder_size))
            self.output_encoder_cell = nn.GRUCell(
                input_size=self.v_input + 1 + self.h_input_encoder_size,
                hidden_size=self.h_output_encoder_size,
                bias=True)
            self.decoder_cell = nn.GRUCell(
                input_size=self.v_target + 1,
                hidden_size=self.h_decoder_size,
                bias=True)
        if cell_type == 'LSTM':
            self.input_encoder_cell = nn.LSTMCell(
                input_size=self.v_input + 1,
                hidden_size=self.h_input_encoder_size,
                bias=True)
            self.input_encoder_init = nn.ParameterList([
                Parameter(torch.rand(1, self.h_input_encoder_size)),
                Parameter(torch.rand(1, self.h_input_encoder_size))])
            self.output_encoder_cell = nn.LSTMCell(
                input_size=self.v_input + 1 + self.h_input_encoder_size,
                hidden_size=self.h_output_encoder_size,
                bias=True)
            self.output_encoder_init_c = Parameter(
                torch.rand(1, self.h_output_encoder_size))
            self.decoder_cell = nn.LSTMCell(
                input_size=self.v_target + 1,
                hidden_size=self.h_decoder_size,
                bias=True)
            self.decoder_init_c = Parameter(torch.rand(1, self.h_decoder_size))

        self.W = nn.Linear(
            self.h_output_encoder_size + self.h_decoder_size,
            self.embedding_size)
        self.V = nn.Linear(self.embedding_size, self.v_target + 1)
        self.input_A = nn.Bilinear(
            self.h_input_encoder_size,
            self.h_output_encoder_size,
            1,
            bias=False)
        self.output_A = nn.Bilinear(
            self.h_output_encoder_size,
            self.h_decoder_size,
            1,
            bias=False)
        self.input_EOS = torch.zeros(1, self.v_input + 1)
        self.input_EOS[:, -1] = 1
        self.input_EOS = Parameter(self.input_EOS)
        self.output_EOS = torch.zeros(1, self.v_input + 1)
        self.output_EOS[:, -1] = 1
        self.output_EOS = Parameter(self.output_EOS)
        self.target_EOS = torch.zeros(1, self.v_target + 1)
        self.target_EOS[:, -1] = 1
        self.target_EOS = Parameter(self.target_EOS)

    def __getstate__(self):
        if hasattr(self, 'opt'):
            return dict([(k, v) for k, v in self.__dict__.items()
                         if k != 'opt'] + [('optstate', self.opt.state_dict())])
            # return {**{k: v for k, v in self.__dict__.items() if k != 'opt'},
            #         'optstate': self.opt.state_dict()}
        else:
            return self.__dict__

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Legacy:
        if isinstance(self.input_encoder_init, tuple):
            self.input_encoder_init = nn.ParameterList(
                list(self.input_encoder_init))

    def clear_optimiser(self):
        if hasattr(self, 'opt'):
            del self.opt
        if hasattr(self, 'optstate'):
            del self.optstate

    def get_optimiser(self):
        self.opt = torch.optim.Adam(self.parameters(), lr=0.001)
        if hasattr(self, 'optstate'):
            self.opt.load_state_dict(self.optstate)

    def optimiser_step(self, inputs, outputs, target):
        if not hasattr(self, 'opt'):
            self.get_optimiser()
        score = self.score(inputs, outputs, target, autograd=True).mean()
        (-score).backward()
        self.opt.step()
        self.opt.zero_grad()
        return score.data[0]

    def set_target_vocabulary(self, target_vocabulary):
        if target_vocabulary == self.target_vocabulary:
            return

        V_weight = []
        V_bias = []
        decoder_ih = []

        for i in range(len(target_vocabulary)):
            if target_vocabulary[i] in self.target_vocabulary:
                j = self.target_vocabulary.index(target_vocabulary[i])
                V_weight.append(self.V.weight.data[j:j + 1])
                V_bias.append(self.V.bias.data[j:j + 1])
                decoder_ih.append(self.decoder_cell.weight_ih.data[:, j:j + 1])
            else:
                V_weight.append(torch.zeros(1, self.V.weight.size(1)))
                V_bias.append(torch.ones(1) * -10)
                decoder_ih.append(
                    torch.zeros(self.decoder_cell.weight_ih.data.size(0), 1))

        V_weight.append(self.V.weight.data[-1:])
        V_bias.append(self.V.bias.data[-1:])
        decoder_ih.append(self.decoder_cell.weight_ih.data[:, -1:])

        self.target_vocabulary = target_vocabulary
        self.v_target = len(target_vocabulary)
        self.target_EOS.data = torch.zeros(1, self.v_target + 1)
        self.target_EOS.data[:, -1] = 1

        self.V.weight.data = torch.cat(V_weight, dim=0)
        self.V.bias.data = torch.cat(V_bias, dim=0)
        self.V.out_features = self.V.bias.data.size(0)

        self.decoder_cell.weight_ih.data = torch.cat(decoder_ih, dim=1)
        self.decoder_cell.input_size = self.decoder_cell.weight_ih.data.size(1)

        self.clear_optimiser()

    def input_encoder_get_init(self, batch_size):
        if self.cell_type == "GRU":
            return self.input_encoder_init.repeat(batch_size, 1)
        if self.cell_type == "LSTM":
            return tuple(x.repeat(batch_size, 1)
                         for x in self.input_encoder_init)

    def output_encoder_get_init(self, input_encoder_h):
        if self.cell_type == "GRU":
            return input_encoder_h
        if self.cell_type == "LSTM":
            return (input_encoder_h,
                    self.output_encoder_init_c.repeat(input_encoder_h.size(0), 1))

    def decoder_get_init(self, output_encoder_h):
        if self.cell_type == "GRU":
            return output_encoder_h
        if self.cell_type == "LSTM":
            return (output_encoder_h,
                    self.decoder_init_c.repeat(output_encoder_h.size(0), 1))

    def cell_get_h(self, cell_state):
        if self.cell_type == "GRU":
            return cell_state
        if self.cell_type == "LSTM":
            return cell_state[0]

    def score(self, inputs, outputs, target, autograd=False):
        inputs = self.inputsToTensors(inputs)
        outputs = self.inputsToTensors(outputs)
        target = self.targetToTensor(target)
        target, score = self.run(inputs, outputs, target=target, mode="score")
        # target = self.tensorToOutput(target)
        if autograd:
            return score
        else:
            return score.data

    def sample(self, inputs, outputs):
        inputs = self.inputsToTensors(inputs)
        outputs = self.inputsToTensors(outputs)
        target, score = self.run(inputs, outputs, mode="sample")
        target = self.tensorToOutput(target)
        return target

    def sampleAndScore(self, inputs, outputs, nRepeats=None):
        inputs = self.inputsToTensors(inputs)
        outputs = self.inputsToTensors(outputs)
        if nRepeats is None:
            target, score = self.run(inputs, outputs, mode="sample")
            target = self.tensorToOutput(target)
            return target, score.data
        else:
            target = []
            score = []
            for i in range(nRepeats):
                # print("repeat %d" % i)
                t, s = self.run(inputs, outputs, mode="sample")
                t = self.tensorToOutput(t)
                target.extend(t)
                score.extend(list(s.data))
            return target, score

    def run(self, inputs, outputs, target=None, mode="sample"):
        """
        :param mode: "score" returns log p(target|input), "sample" returns target ~ p(-|input)
        :param List[LongTensor] inputs: n_examples * (max_length_input * batch_size)
        :param List[LongTensor] target: max_length_target * batch_size
        """
        assert((mode == "score" and target is not None) or mode == "sample")

        n_examples = len(inputs)
        max_length_input = [inputs[j].size(0) for j in range(n_examples)]
        max_length_output = [outputs[j].size(0) for j in range(n_examples)]
        max_length_target = target.size(0) if target is not None else 10
        batch_size = inputs[0].size(1)

        score = Variable(torch.zeros(batch_size))
        # n_examples * (max_length_input * batch_size * v_input+1)
        inputs_scatter = [Variable(torch.zeros(max_length_input[j], batch_size, self.v_input + 1).scatter_(
            2, inputs[j][:, :, None], 1)) for j in range(n_examples)]
        # n_examples * (max_length_output * batch_size * v_input+1)
        outputs_scatter = [Variable(torch.zeros(max_length_output[j], batch_size, self.v_input + 1).scatter_(
            2, outputs[j][:, :, None], 1)) for j in range(n_examples)]
        if target is not None:
            # max_length_target * batch_size * v_target+1
            target_scatter = Variable(torch.zeros(
                max_length_target, batch_size, self.v_target + 1).scatter_(
                    2, target[:, :, None], 1))

        # -------------- Input Encoder -------------

        # n_examples * (max_length_input * batch_size * h_encoder_size)
        input_H = []
        input_embeddings = []  # h for example at INPUT_EOS
        # 0 until (and including) INPUT_EOS, then -inf
        input_attention_mask = []
        for j in range(n_examples):
            active = torch.Tensor(max_length_input[j], batch_size).byte()
            active[0, :] = 1
            state = self.input_encoder_get_init(batch_size)
            hs = []
            for i in range(max_length_input[j]):
                state = self.input_encoder_cell(
                    inputs_scatter[j][i, :, :], state)
                if i + 1 < max_length_input[j]:
                    active[i + 1, :] = active[i, :] * \
                        (inputs[j][i, :] != self.v_input)
                h = self.cell_get_h(state)
                hs.append(h[None, :, :])
            input_H.append(torch.cat(hs, 0))
            embedding_idx = active.sum(0).long() - 1
            embedding = input_H[j].gather(0, Variable(
                embedding_idx[None, :, None].repeat(1, 1, self.h_input_encoder_size)))[0]
            input_embeddings.append(embedding)
            input_attention_mask.append(Variable(active.float().log()))

        # -------------- Output Encoder -------------

        def input_attend(j, h_out):
            """
            'general' attention from https://arxiv.org/pdf/1508.04025.pdf
            :param j: Index of example
            :param h_out: batch_size * h_output_encoder_size
            """
            scores = self.input_A(
                input_H[j].view(
                    max_length_input[j] * batch_size,
                    self.h_input_encoder_size),
                h_out.view(
                    batch_size,
                    self.h_output_encoder_size).repeat(
                    max_length_input[j], 1)).view(
                max_length_input[j], batch_size) + input_attention_mask[j]
            c = (F.softmax(scores[:, :, None], dim=0) * input_H[j]).sum(0)
            return c

        # n_examples * (max_length_input * batch_size * h_encoder_size)
        output_H = []
        output_embeddings = []  # h for example at INPUT_EOS
        # 0 until (and including) INPUT_EOS, then -inf
        output_attention_mask = []
        for j in range(n_examples):
            active = torch.Tensor(max_length_output[j], batch_size).byte()
            active[0, :] = 1
            state = self.output_encoder_get_init(input_embeddings[j])
            hs = []
            h = self.cell_get_h(state)
            for i in range(max_length_output[j]):
                state = self.output_encoder_cell(torch.cat(
                    [outputs_scatter[j][i, :, :], input_attend(j, h)], 1), state)
                if i + 1 < max_length_output[j]:
                    active[i + 1, :] = active[i, :] * \
                        (outputs[j][i, :] != self.v_input)
                h = self.cell_get_h(state)
                hs.append(h[None, :, :])
            output_H.append(torch.cat(hs, 0))
            embedding_idx = active.sum(0).long() - 1
            embedding = output_H[j].gather(0, Variable(
                embedding_idx[None, :, None].repeat(1, 1, self.h_output_encoder_size)))[0]
            output_embeddings.append(embedding)
            output_attention_mask.append(Variable(active.float().log()))

        # ------------------ Decoder -----------------

        def output_attend(j, h_dec):
            """
            'general' attention from https://arxiv.org/pdf/1508.04025.pdf
            :param j: Index of example
            :param h_dec: batch_size * h_decoder_size
            """
            scores = self.output_A(
                output_H[j].view(
                    max_length_output[j] * batch_size,
                    self.h_output_encoder_size),
                h_dec.view(
                    batch_size,
                    self.h_decoder_size).repeat(
                    max_length_output[j], 1)).view(
                max_length_output[j], batch_size) + output_attention_mask[j]
            c = (F.softmax(scores[:, :, None], dim=0) * output_H[j]).sum(0)
            return c

        # Multi-example pooling: Figure 3, https://arxiv.org/pdf/1703.07469.pdf
        target = target if mode == "score" else torch.zeros(
            max_length_target, batch_size).long()
        decoder_states = [
            self.decoder_get_init(output_embeddings[j])
            for j in range(n_examples)]  # P
        active = torch.ones(batch_size).byte()
        for i in range(max_length_target):
            FC = []
            for j in range(n_examples):
                h = self.cell_get_h(decoder_states[j])
                p_aug = torch.cat([h, output_attend(j, h)], 1)
                FC.append(F.tanh(self.W(p_aug)[None, :, :]))
            # batch_size * embedding_size
            m = torch.max(torch.cat(FC, 0), 0)[0]
            logsoftmax = F.log_softmax(self.V(m), dim=1)
            if mode == "sample":
                target[i, :] = torch.multinomial(
                    logsoftmax.data.exp(), 1)[:, 0]
            score = score + \
                choose(logsoftmax, target[i, :]) * Variable(active.float())
            active *= (target[i, :] != self.v_target)
            for j in range(n_examples):
                if mode == "score":
                    target_char_scatter = target_scatter[i, :, :]
                elif mode == "sample":
                    target_char_scatter = Variable(torch.zeros(
                        batch_size, self.v_target + 1).scatter_(1, target[i, :, None], 1))
                decoder_states[j] = self.decoder_cell(
                    target_char_scatter, decoder_states[j])
        return target, score

    def inputsToTensors(self, inputss):
        """
        :param inputss: size = nBatch * nExamples
        """
        tensors = []
        for j in range(len(inputss[0])):
            inputs = [x[j] for x in inputss]
            maxlen = max(len(s) for s in inputs)
            t = torch.ones(
                1 if maxlen == 0 else maxlen + 1,
                len(inputs)).long() * self.v_input
            for i in range(len(inputs)):
                s = inputs[i]
                if len(s) > 0:
                    t[:len(s), i] = torch.LongTensor(
                        [self.input_vocabulary.index(x) for x in s])
            tensors.append(t)
        return tensors

    def targetToTensor(self, targets):
        """
        :param targets: nBatch token sequences
        """
        maxlen = max(len(s) for s in targets)
        t = torch.ones(
            1 if maxlen == 0 else maxlen + 1,
            len(targets)).long() * self.v_target
        for i in range(len(targets)):
            s = targets[i]
            if len(s) > 0:
                t[:len(s), i] = torch.LongTensor(
                    [self.target_vocabulary.index(x) for x in s])
        return t

    def tensorToOutput(self, tensor):
        """
        :param tensor: max_length * batch_size
        """
        out = []
        for i in range(tensor.size(1)):
            l = tensor[:, i].tolist()
            if l[0] == self.v_target:
                out.append([])
            elif self.v_target in l:
                final = tensor[:, i].tolist().index(self.v_target)
                out.append([self.target_vocabulary[x]
                            for x in tensor[:final, i]])
            else:
                out.append([self.target_vocabulary[x] for x in tensor[:, i]])
        return out
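
A note on the `choose` helper at the top of this file: it picks matrix[i, idxs[i]] for every row i by flattening the matrix and offsetting each index by its row. Current PyTorch expresses the same thing with `torch.gather`; a small equivalence check (illustrative only, not part of the commit):

    import torch

    matrix = torch.tensor([[0.1, 0.9], [0.7, 0.3]])
    idxs = torch.tensor([1, 0])

    # The flatten-and-offset trick used by choose() above.
    flat = matrix.view(-1)[idxs + torch.arange(matrix.size(0)) * matrix.size(1)]
    # The modern equivalent.
    gathered = matrix.gather(1, idxs[:, None])[:, 0]

    assert torch.allclose(flat, gathered)
    print(gathered)  # tensor([0.9000, 0.7000])
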
dreamcoder/differentiation.py ADDED
@@ -0,0 +1,393 @@

import math
import random
from dreamcoder.utilities import *


class InvalidLoss(Exception):
    pass


class DN(object):
    '''differentiable node: parent object of every differentiable operation'''

    def __init__(self, arguments):
        self.gradient = None
        if arguments != []:
            self.data = None
        self.arguments = arguments

        # descendents: every variable that takes this variable as input
        # descendents: [(DN,float)]
        # the additional float parameter is d Descendent / d This
        self.descendents = []

        self.recalculate()

    def __str__(self):
        if self.arguments == []:
            return self.name
        return "(%s %s)" % (self.name, " ".join(str(x)
                                                for x in self.arguments))

    def __repr__(self):
        return "DN(op = %s, data = %s, grad = %s, #descendents = %d, args = %s)" % (
            self.name, self.data, self.gradient, len(self.descendents), self.arguments)

    @property
    def derivative(self): return self.differentiate()

    def differentiate(self):
        if self.gradient is None:
            self.gradient = sum(partial * descendent.differentiate()
                                for descendent, partial in self.descendents)
        return self.gradient

    def zeroEverything(self):
        if self.gradient is None and self.descendents == [] and (
                self.data is None or self.arguments == []):
            return

        self.gradient = None
        self.descendents = []
        if self.arguments != []:
            self.data = None

        for x in self.arguments:
            x.zeroEverything()

    def lightweightRecalculate(self):
        return self.forward(*[a.lightweightRecalculate()
                              for a in self.arguments])

    def recalculate(self):
        if self.data is None:
            inputs = [a.recalculate() for a in self.arguments]
            self.data = self.forward(*inputs)
            # if invalid(self.data):
            #     eprint("I am invalid", repr(self))
            #     eprint("Here are my inputs", inputs)
            #     self.zeroEverything()
            #     eprint("Here I am after being zeroed", repr(self))
            #     raise Exception('invalid loss')
            # assert valid(self.data)
            partials = self.backward(*inputs)
            for d, a in zip(partials, self.arguments):
                # if invalid(d):
                #     eprint("I have an invalid derivative", self)
                #     eprint("Inputs", inputs)
                #     eprint("partials", partials)
                #     raise Exception('invalid derivative')
                a.descendents.append((self, d))
        return self.data

    def backPropagation(self):
        self.gradient = 1.
        self.recursivelyDifferentiate()

    def recursivelyDifferentiate(self):
        self.differentiate()
        for x in self.arguments:
            x.recursivelyDifferentiate()

    def updateNetwork(self):
        self.zeroEverything()
        l = self.recalculate()
        self.backPropagation()
        return l

    def log(self): return Logarithm(self)

    def square(self): return Square(self)

    def exp(self): return Exponentiation(self)

    def clamp(self, l, u): return Clamp(self, l, u)

    def __abs__(self): return AbsoluteValue(self)

    def __add__(self, o): return Addition(self, Placeholder.maybe(o))

    def __radd__(self, o): return Addition(self, Placeholder.maybe(o))

    def __sub__(self, o): return Subtraction(self, Placeholder.maybe(o))

    def __rsub__(self, o): return Subtraction(Placeholder.maybe(o), self)

    def __mul__(self, o): return Multiplication(self, Placeholder.maybe(o))

    def __rmul__(self, o): return Multiplication(self, Placeholder.maybe(o))

    def __neg__(self): return Negation(self)

    def __truediv__(self, o): return Division(self, Placeholder.maybe(o))

    def __rtruediv__(self, o): return Division(Placeholder.maybe(o), self)

    def numericallyVerifyGradients(self, parameters):
        calculatedGradients = [p.derivative for p in parameters]
        e = 0.00001
        for j, p in enumerate(parameters):
            p.data -= e
            y1 = self.lightweightRecalculate()
            p.data += 2 * e
            y2 = self.lightweightRecalculate()
            p.data -= e
            d = (y2 - y1) / (2 * e)
            if abs(calculatedGradients[j] - d) > 0.1:
                eprint(
                    "Bad gradient: expected %f, got %f" %
                    (d, calculatedGradients[j]))

    def gradientDescent(
            self,
            parameters,
            _=None,
            lr=0.001,
            steps=10**3,
            update=None):
        for j in range(steps):
            l = self.updateNetwork()
            if update is not None and j % update == 0:
                eprint("LOSS:", l)
                for p in parameters:
                    eprint(p.data, '\t', p.derivative)
            if invalid(l):
                raise InvalidLoss()

            for p in parameters:
                p.data -= lr * p.derivative
        return self.data

    def restartingOptimize(self, parameters, _=None, attempts=1,
                           s=1., decay=0.5, grow=0.1,
                           lr=0.1, steps=10**3, update=None):
        ls = []
        for _ in range(attempts):
            for p in parameters:
                p.data = random.random() * 10 - 5
            ls.append(
                self.resilientBackPropagation(
                    parameters, lr=lr, steps=steps,
                    decay=decay, grow=grow))
        return min(ls)

    def resilientBackPropagation(
            self,
            parameters,
            _=None,
            decay=0.5,
            grow=1.2,
            lr=0.1,
            steps=10**3,
            update=None):
        previousSign = [None] * len(parameters)
        lr = [lr] * len(parameters)
        for j in range(steps):
            l = self.updateNetwork()

            if update is not None and j % update == 0:
                eprint("LOSS:", l)
                eprint("\t".join(str(p.derivative) for p in parameters))
            if invalid(l):
                raise InvalidLoss()

            newSigns = [p.derivative > 0 for p in parameters]
            for i, p in enumerate(parameters):
                if p.derivative > 0:
                    p.data -= lr[i]
                elif p.derivative < 0:
                    p.data += lr[i]
                if previousSign[i] is not None:
                    if previousSign[i] == newSigns[i]:
                        lr[i] *= grow
                    else:
                        lr[i] *= decay
            previousSign = newSigns

        return self.data


class Placeholder(DN):
    COUNTER = 0

    def __init__(self, initialValue=0., name=None):
        self.data = initialValue
        super(Placeholder, self).__init__([])
        if name is None:
            name = "p_" + str(Placeholder.COUNTER)
            Placeholder.COUNTER += 1
        self.name = name

    @staticmethod
    def named(namePrefix, initialValue=0.):
        p = Placeholder(initialValue, namePrefix + str(Placeholder.COUNTER))
        Placeholder.COUNTER += 1
        return p

    def __str__(self):
        return "Placeholder(%s = %s)" % (self.name, self.data)

    @staticmethod
    def maybe(x):
        if isinstance(x, DN):
            return x
        return Placeholder(float(x))

    def forward(self): return self.data

    def backward(self): return []


class Clamp(DN):
    def __init__(self, x, l, u):
        assert u > l
        self.l = l
        self.u = u
        super(Clamp, self).__init__([x])
        self.name = "clamp"

    def forward(self, x):
        if x > self.u:
            return self.u
        if x < self.l:
            return self.l
        return x

    def backward(self, x):
        if x > self.u or x < self.l:
            return [0.]
        else:
            return [1.]


class Addition(DN):
    def __init__(self, x, y):
        super(Addition, self).__init__([x, y])
        self.name = '+'

    def forward(self, x, y): return x + y

    def backward(self, x, y): return [1., 1.]


class Subtraction(DN):
    def __init__(self, x, y):
        super(Subtraction, self).__init__([x, y])
        self.name = '-'

    def forward(self, x, y): return x - y

    def backward(self, x, y): return [1., -1.]


class Negation(DN):
    def __init__(self, x):
        super(Negation, self).__init__([x])
        self.name = '-'

    def forward(self, x): return -x

    def backward(self, x): return [-1.]


class AbsoluteValue(DN):
    def __init__(self, x):
        super(AbsoluteValue, self).__init__([x])
        self.name = 'abs'

    def forward(self, x): return abs(x)

    def backward(self, x):
        if x > 0:
            return [1.]
        return [-1.]


class Multiplication(DN):
    def __init__(self, x, y):
        super(Multiplication, self).__init__([x, y])
        self.name = '*'

    def forward(self, x, y): return x * y

    def backward(self, x, y): return [y, x]


class Division(DN):
    def __init__(self, x, y):
        super(Division, self).__init__([x, y])
        self.name = '/'

    def forward(self, x, y): return x / y

    def backward(self, x, y): return [1.0 / y, -x / (y * y)]


class Square(DN):
    def __init__(self, x):
        super(Square, self).__init__([x])
        self.name = 'sq'

    def forward(self, x): return x * x

    def backward(self, x): return [2 * x]


class Exponentiation(DN):
    def __init__(self, x):
        super(Exponentiation, self).__init__([x])
        self.name = 'exp'

    def forward(self, x): return math.exp(x)

    def backward(self, x): return [math.exp(x)]


class Logarithm(DN):
    def __init__(self, x):
        super(Logarithm, self).__init__([x])
        self.name = 'log'

    def forward(self, x): return math.log(x)

    def backward(self, x): return [1. / x]


class LSE(DN):
    def __init__(self, xs):
        super(LSE, self).__init__(xs)
        self.name = 'LSE'

    def forward(self, *xs):
        m = max(xs)
        return m + math.log(sum(math.exp(y - m) for y in xs))

    def backward(self, *xs):
        m = max(xs)
        zm = sum(math.exp(x - m) for x in xs)
        return [math.exp(x - m) / zm for x in xs]


if __name__ == "__main__":
    x = Placeholder(10., "x")
    y = Placeholder(2., "y")
    z = x - LSE([x, y])
    z.updateNetwork()
    eprint("dL/dx = %f\tdL/dy = %f" % (x.derivative, y.derivative))

    x.data = 2.
    y.data = 10.
    z.updateNetwork()
    eprint("dL/dx = %f\tdL/dy = %f" % (x.differentiate(), y.differentiate()))

    x.data = 2.
    y.data = 2.
    z.updateNetwork()
    eprint("z = ", z.data, z)
    eprint("dL/dx = %f\tdL/dy = %f" % (x.differentiate(), y.differentiate()))

    loss = -z
    eprint(loss)

    lr = 0.001
    loss.gradientDescent([x, y], steps=10000, update=1000)
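
As a sanity check on the demo above: for z = x - LSE([x, y]), the exact partials are dz/dx = 1 - softmax(x) and dz/dy = -softmax(y), with the softmax taken over {x, y}. Hand-computing them for the x = y = 2 case the script prints:

    import math

    x, y = 2.0, 2.0
    m = max(x, y)
    ex, ey = math.exp(x - m), math.exp(y - m)
    z = ex + ey             # softmax normalizer, as in LSE.backward above
    dz_dx = 1.0 - ex / z    # 0.5
    dz_dy = -ey / z         # -0.5
    print(dz_dx, dz_dy)     # 0.5 -0.5
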
dreamcoder/domains/__init__.py ADDED
File without changes
dreamcoder/domains/arithmetic/__init__.py ADDED
File without changes
dreamcoder/domains/arithmetic/arithmeticPrimitives.py ADDED
@@ -0,0 +1,58 @@

from dreamcoder.program import *
from dreamcoder.type import *


def _addition(x): return lambda y: x + y


def _subtraction(x): return lambda y: x - y


def _division(x): return lambda y: x / y


subtraction = Primitive("-",
                        arrow(tint, arrow(tint, tint)),
                        _subtraction)
real_subtraction = Primitive("-.",
                             arrow(treal, treal, treal),
                             _subtraction)
addition = Primitive("+",
                     arrow(tint, arrow(tint, tint)),
                     Curried(_addition))
real_addition = Primitive("+.",
                          arrow(treal, treal, treal),
                          _addition)


def _multiplication(x): return lambda y: x * y


multiplication = Primitive("*",
                           arrow(tint, arrow(tint, tint)),
                           _multiplication)
real_multiplication = Primitive("*.",
                                arrow(treal, treal, treal),
                                _multiplication)
real_division = Primitive("/.",
                          arrow(treal, treal, treal),
                          _division)


def _power(a): return lambda b: a**b


real_power = Primitive("power",
                       arrow(treal, treal, treal),
                       _power)

k1 = Primitive("1", tint, 1)
k_negative1 = Primitive("negative_1", tint, -1)
k0 = Primitive("0", tint, 0)
for n in range(2, 10):
    Primitive(str(n), tint, n)

f1 = Primitive("1.", treal, 1.)
f0 = Primitive("0.", treal, 0)
real = Primitive("REAL", treal, None)
fpi = Primitive("pi", treal, 3.14)
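
Every binary primitive in this file is curried: a type like arrow(tint, arrow(tint, tint)) (i.e. int -> int -> int) is paired with a one-argument Python function that returns another one-argument function, so partial application in the DSL maps directly onto Python application. A plain-Python illustration of the convention (no dreamcoder imports needed):

    def _addition(x):
        return lambda y: x + y   # same shape as the primitives above

    add_two = _addition(2)       # partial application, like (+ 2) in the DSL
    print(add_two(3))            # 5
    print(_addition(7)(-7))      # 0
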
dreamcoder/domains/list/__init__.py ADDED
File without changes
dreamcoder/domains/list/listPrimitives.py ADDED
@@ -0,0 +1,546 @@

from dreamcoder.program import Primitive, Program
from dreamcoder.grammar import Grammar
from dreamcoder.type import tlist, tint, tbool, arrow, t0, t1, t2

import math
from functools import reduce


def _flatten(l): return [x for xs in l for x in xs]


def _range(n):
    if n < 100:
        return list(range(n))
    raise ValueError()


def _if(c): return lambda t: lambda f: t if c else f


def _and(x): return lambda y: x and y


def _or(x): return lambda y: x or y


def _addition(x): return lambda y: x + y


def _subtraction(x): return lambda y: x - y


def _multiplication(x): return lambda y: x * y


def _negate(x): return -x


def _reverse(x): return list(reversed(x))


def _append(x): return lambda y: x + y


def _cons(x): return lambda y: [x] + y


def _car(x): return x[0]


def _cdr(x): return x[1:]


def _isEmpty(x): return x == []


def _single(x): return [x]


def _slice(x): return lambda y: lambda l: l[x:y]


def _map(f): return lambda l: list(map(f, l))


def _zip(a): return lambda b: lambda f: list(map(lambda x, y: f(x)(y), a, b))


def _mapi(f): return lambda l: list(map(lambda i_x: f(i_x[0])(i_x[1]), enumerate(l)))


def _reduce(f): return lambda x0: lambda l: reduce(lambda a, x: f(a)(x), l, x0)


def _reducei(f): return lambda x0: lambda l: reduce(
    lambda a, t: f(t[0])(a)(t[1]), enumerate(l), x0)


def _fold(l): return lambda x0: lambda f: reduce(
    lambda a, x: f(x)(a), l[::-1], x0)


def _eq(x): return lambda y: x == y


def _eq0(x): return x == 0


def _a1(x): return x + 1


def _d1(x): return x - 1


def _mod(x): return lambda y: x % y


def _not(x): return not x


def _gt(x): return lambda y: x > y


def _index(j): return lambda l: l[j]


def _replace(f): return lambda lnew: lambda lin: _flatten(
    lnew if f(i)(x) else [x] for i, x in enumerate(lin))


def _isPrime(n):
    return n in {
        2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61,
        67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137,
        139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199}


def _isSquare(n):
    return int(math.sqrt(n)) ** 2 == n


# fixed: the original lambda was never returned, so this always returned None
def _appendmap(f): return lambda xs: [y for x in xs for y in f(x)]


def _filter(f): return lambda l: list(filter(f, l))


def _any(f): return lambda l: any(f(x) for x in l)


def _all(f): return lambda l: all(f(x) for x in l)


def _find(x):
    def _inner(l):
        try:
            return l.index(x)
        except ValueError:
            return -1
    return _inner


# fixed: `f` was undefined here (NameError at call time); the step function is `h`
def _unfold(x): return lambda p: lambda h: lambda n: __unfold(p, h, n, x)


def __unfold(p, f, n, x, recursion_limit=50):
    if recursion_limit <= 0:
        raise RecursionDepthExceeded()
    if p(x):
        return []
    return [f(x)] + __unfold(p, f, n, n(x), recursion_limit - 1)


class RecursionDepthExceeded(Exception):
    pass


def _fix(argument):
    def inner(body):
        recursion_limit = [20]

        def fix(x):
            def r(z):
                recursion_limit[0] -= 1
                if recursion_limit[0] <= 0:
                    raise RecursionDepthExceeded()
                else:
                    return fix(z)

            return body(r)(x)
        return fix(argument)

    return inner


def curry(f): return lambda x: lambda y: f((x, y))


def _fix2(a1):
    return lambda a2: lambda body: \
        _fix((a1, a2))(lambda r: lambda n_l: body(curry(r))(n_l[0])(n_l[1]))


primitiveRecursion1 = Primitive("fix1",
                                arrow(t0,
                                      arrow(arrow(t0, t1), t0, t1),
                                      t1),
                                _fix)

primitiveRecursion2 = Primitive("fix2",
                                arrow(t0, t1,
                                      arrow(arrow(t0, t1, t2), t0, t1, t2),
                                      t2),
                                _fix2)


def _match(l):
    return lambda b: lambda f: b if l == [] else f(l[0])(l[1:])


def primitives():
    return [Primitive(str(j), tint, j) for j in range(6)] + [
        Primitive("empty", tlist(t0), []),
        Primitive("singleton", arrow(t0, tlist(t0)), _single),
        Primitive("range", arrow(tint, tlist(tint)), _range),
        Primitive("++", arrow(tlist(t0), tlist(t0), tlist(t0)), _append),
        # Primitive("map", arrow(arrow(t0, t1), tlist(t0), tlist(t1)), _map),
        Primitive("mapi", arrow(arrow(tint, t0, t1), tlist(t0), tlist(t1)), _mapi),
        # Primitive("reduce", arrow(arrow(t1, t0, t1), t1, tlist(t0), t1), _reduce),
        Primitive("reducei", arrow(arrow(tint, t1, t0, t1), t1, tlist(t0), t1), _reducei),

        Primitive("true", tbool, True),
        Primitive("not", arrow(tbool, tbool), _not),
        Primitive("and", arrow(tbool, tbool, tbool), _and),
        Primitive("or", arrow(tbool, tbool, tbool), _or),
        # Primitive("if", arrow(tbool, t0, t0, t0), _if),

        Primitive("sort", arrow(tlist(tint), tlist(tint)), sorted),
        Primitive("+", arrow(tint, tint, tint), _addition),
        Primitive("*", arrow(tint, tint, tint), _multiplication),
        Primitive("negate", arrow(tint, tint), _negate),
        Primitive("mod", arrow(tint, tint, tint), _mod),
        Primitive("eq?", arrow(tint, tint, tbool), _eq),
        Primitive("gt?", arrow(tint, tint, tbool), _gt),
        Primitive("is-prime", arrow(tint, tbool), _isPrime),
        Primitive("is-square", arrow(tint, tbool), _isSquare),

        # these are achievable with above primitives, but unlikely
        # Primitive("flatten", arrow(tlist(tlist(t0)), tlist(t0)), _flatten),
        # (lambda (reduce (lambda (lambda (++ $1 $0))) empty $0))
        Primitive("sum", arrow(tlist(tint), tint), sum),
        # (lambda (lambda (reduce (lambda (lambda (+ $0 $1))) 0 $0)))
        Primitive("reverse", arrow(tlist(t0), tlist(t0)), _reverse),
        # (lambda (reduce (lambda (lambda (++ (singleton $0) $1))) empty $0))
        Primitive("all", arrow(arrow(t0, tbool), tlist(t0), tbool), _all),
        # (lambda (lambda (reduce (lambda (lambda (and $0 $1))) true (map $1 $0))))
        Primitive("any", arrow(arrow(t0, tbool), tlist(t0), tbool), _any),
        # (lambda (lambda (reduce (lambda (lambda (or $0 $1))) true (map $1 $0))))
        Primitive("index", arrow(tint, tlist(t0), t0), _index),
        # (lambda (lambda (reducei (lambda (lambda (lambda (if (eq? $1 $4) $0 0)))) 0 $0)))
        Primitive("filter", arrow(arrow(t0, tbool), tlist(t0), tlist(t0)), _filter),
        # (lambda (lambda (reduce (lambda (lambda (++ $1 (if ($3 $0) (singleton $0) empty)))) empty $0)))
        # Primitive("replace", arrow(arrow(tint, t0, tbool), tlist(t0), tlist(t0), tlist(t0)), _replace),
        # (FLATTEN (lambda (lambda (lambda (mapi (lambda (lambda (if ($4 $1 $0) $3 (singleton $1)))) $0)))))
        Primitive("slice", arrow(tint, tint, tlist(t0), tlist(t0)), _slice),
        # (lambda (lambda (lambda (reducei (lambda (lambda (lambda (++ $2 (if (and (or (gt? $1 $5) (eq? $1 $5)) (not (or (gt? $4 $1) (eq? $1 $4)))) (singleton $0) empty))))) empty $0))))
    ]


def basePrimitives():
    return [Primitive(str(j), tint, j) for j in range(6)] + [
        Primitive("*", arrow(tint, tint, tint), _multiplication),
        Primitive("gt?", arrow(tint, tint, tbool), _gt),
        Primitive("is-prime", arrow(tint, tbool), _isPrime),
        Primitive("is-square", arrow(tint, tbool), _isSquare),
        # McCarthy
        Primitive("empty", tlist(t0), []),
        Primitive("cons", arrow(t0, tlist(t0), tlist(t0)), _cons),
        Primitive("car", arrow(tlist(t0), t0), _car),
        Primitive("cdr", arrow(tlist(t0), tlist(t0)), _cdr),
        Primitive("empty?", arrow(tlist(t0), tbool), _isEmpty),
        Primitive("if", arrow(tbool, t0, t0, t0), _if),
        Primitive("eq?", arrow(tint, tint, tbool), _eq),
        Primitive("+", arrow(tint, tint, tint), _addition),
        Primitive("-", arrow(tint, tint, tint), _subtraction)
    ]


zip_primitive = Primitive("zip", arrow(tlist(t0), tlist(t1), arrow(t0, t1, t2), tlist(t2)), _zip)


def bootstrapTarget():
    """These are the primitives that we hope to learn from the bootstrapping procedure"""
    return [
        # learned primitives
        Primitive("map", arrow(arrow(t0, t1), tlist(t0), tlist(t1)), _map),
        Primitive("unfold", arrow(t0, arrow(t0, tbool), arrow(t0, t1), arrow(t0, t0), tlist(t1)), _unfold),
        Primitive("range", arrow(tint, tlist(tint)), _range),
        Primitive("index", arrow(tint, tlist(t0), t0), _index),
        Primitive("fold", arrow(tlist(t0), t1, arrow(t0, t1, t1), t1), _fold),
        Primitive("length", arrow(tlist(t0), tint), len),
339
+
340
+ # built-ins
341
+ Primitive("if", arrow(tbool, t0, t0, t0), _if),
342
+ Primitive("+", arrow(tint, tint, tint), _addition),
343
+ Primitive("-", arrow(tint, tint, tint), _subtraction),
344
+ Primitive("empty", tlist(t0), []),
345
+ Primitive("cons", arrow(t0, tlist(t0), tlist(t0)), _cons),
346
+ Primitive("car", arrow(tlist(t0), t0), _car),
347
+ Primitive("cdr", arrow(tlist(t0), tlist(t0)), _cdr),
348
+ Primitive("empty?", arrow(tlist(t0), tbool), _isEmpty),
349
+ ] + [Primitive(str(j), tint, j) for j in range(2)]
350
+
351
+
352
+ def bootstrapTarget_extra():
353
+ """This is the bootstrap target plus list domain specific stuff"""
354
+ return bootstrapTarget() + [
355
+ Primitive("*", arrow(tint, tint, tint), _multiplication),
356
+ Primitive("mod", arrow(tint, tint, tint), _mod),
357
+ Primitive("gt?", arrow(tint, tint, tbool), _gt),
358
+ Primitive("eq?", arrow(tint, tint, tbool), _eq),
359
+ Primitive("is-prime", arrow(tint, tbool), _isPrime),
360
+ Primitive("is-square", arrow(tint, tbool), _isSquare),
361
+ ]
362
+
363
+ def no_length():
364
+ """this is the primitives without length because one of the reviewers wanted this"""
365
+ return [p for p in bootstrapTarget() if p.name != "length"] + [
366
+ Primitive("*", arrow(tint, tint, tint), _multiplication),
367
+ Primitive("mod", arrow(tint, tint, tint), _mod),
368
+ Primitive("gt?", arrow(tint, tint, tbool), _gt),
369
+ Primitive("eq?", arrow(tint, tint, tbool), _eq),
370
+ Primitive("is-prime", arrow(tint, tbool), _isPrime),
371
+ Primitive("is-square", arrow(tint, tbool), _isSquare),
372
+ ]
373
+
374
+
375
+ def McCarthyPrimitives():
376
+ "These are < primitives provided by 1959 lisp as introduced by McCarthy"
377
+ return [
378
+ Primitive("empty", tlist(t0), []),
379
+ Primitive("cons", arrow(t0, tlist(t0), tlist(t0)), _cons),
380
+ Primitive("car", arrow(tlist(t0), t0), _car),
381
+ Primitive("cdr", arrow(tlist(t0), tlist(t0)), _cdr),
382
+ Primitive("empty?", arrow(tlist(t0), tbool), _isEmpty),
383
+ #Primitive("unfold", arrow(t0, arrow(t0,t1), arrow(t0,t0), arrow(t0,tbool), tlist(t1)), _isEmpty),
384
+ #Primitive("1+", arrow(tint,tint),None),
385
+ # Primitive("range", arrow(tint, tlist(tint)), range),
386
+ # Primitive("map", arrow(arrow(t0, t1), tlist(t0), tlist(t1)), _map),
387
+ # Primitive("index", arrow(tint,tlist(t0),t0),None),
388
+ # Primitive("length", arrow(tlist(t0),tint),None),
389
+ primitiveRecursion1,
390
+ #primitiveRecursion2,
391
+ Primitive("gt?", arrow(tint, tint, tbool), _gt),
392
+ Primitive("if", arrow(tbool, t0, t0, t0), _if),
393
+ Primitive("eq?", arrow(tint, tint, tbool), _eq),
394
+ Primitive("+", arrow(tint, tint, tint), _addition),
395
+ Primitive("-", arrow(tint, tint, tint), _subtraction),
396
+ ] + [Primitive(str(j), tint, j) for j in range(2)]
397
+
398
+
399
+ if __name__ == "__main__":
400
+ bootstrapTarget()
401
+ g = Grammar.uniform(McCarthyPrimitives())
402
+ # with open("/home/ellisk/om/ec/experimentOutputs/list_aic=1.0_arity=3_ET=1800_expandFrontier=2.0_it=4_likelihoodModel=all-or-nothing_MF=5_baseline=False_pc=10.0_L=1.0_K=5_rec=False.pickle", "rb") as handle:
403
+ # b = pickle.load(handle).grammars[-1]
404
+ # print b
405
+
406
+ p = Program.parse(
407
+ "(lambda (lambda (lambda (if (empty? $0) empty (cons (+ (car $1) (car $0)) ($2 (cdr $1) (cdr $0)))))))")
408
+ t = arrow(tlist(tint), tlist(tint), tlist(tint)) # ,tlist(tbool))
409
+ print(g.logLikelihood(arrow(t, t), p))
410
+ assert False
411
+ # print(b.logLikelihood(arrow(t, t), p))  # requires the pickled grammar b commented out above
412
+
413
+ # p = Program.parse("""(lambda (lambda
414
+ # (unfold 0
415
+ # (lambda (+ (index $0 $2) (index $0 $1)))
416
+ # (lambda (1+ $0))
417
+ # (lambda (eq? $0 (length $1))))))
418
+ # """)
419
+ p = Program.parse("""(lambda (lambda
420
+ (map (lambda (+ (index $0 $2) (index $0 $1))) (range (length $0)) )))""")
421
+ # .replace("unfold", "#(lambda (lambda (lambda (lambda (fix1 $0 (lambda (lambda (#(lambda (lambda (lambda (if $0 empty (cons $1 $2))))) ($1 ($3 $0)) ($4 $0) ($5 $0)))))))))").\
422
+ # replace("length", "#(lambda (fix1 $0 (lambda (lambda (if (empty? $0) 0 (+ ($1 (cdr $0)) 1))))))").\
423
+ # replace("forloop", "(#(lambda (lambda (lambda (lambda (fix1 $0 (lambda (lambda (#(lambda (lambda (lambda (if $0 empty (cons $1 $2))))) ($1 ($3 $0)) ($4 $0) ($5 $0))))))))) (lambda (#(eq? 0) $0)) $0 (lambda (#(lambda (- $0 1)) $0)))").\
424
+ # replace("inc","#(lambda (+ $0 1))").\
425
+ # replace("drop","#(lambda (lambda (fix2 $0 $1 (lambda (lambda (lambda (if
426
+ # (#(eq? 0) $1) $0 (cdr ($2 (- $1 1) $0)))))))))"))
427
+ print(p)
428
+ print(g.logLikelihood(t, p))
429
+ assert False
430
+
431
+ print("??")
432
+ p = Program.parse(
433
+ "#(lambda (#(lambda (lambda (lambda (fix1 $0 (lambda (lambda (if (empty? $0) $3 ($4 (car $0) ($1 (cdr $0)))))))))) (lambda $1) 1))")
434
+ for j in range(10):
435
+ l = list(range(j))
436
+ print(l, p.evaluate([])(lambda x: x * 2)(l))
437
+ print()
438
+ print()
439
+
440
+ print("multiply")
441
+ p = Program.parse(
442
+ "(lambda (lambda (lambda (if (eq? $0 0) 0 (+ $1 ($2 $1 (- $0 1)))))))")
443
+ print(g.logLikelihood(arrow(arrow(tint, tint, tint), tint, tint, tint), p))
444
+ print()
445
+
446
+ print("take until 0")
447
+ p = Program.parse("(lambda (lambda (if (eq? $1 0) empty (cons $1 $0))))")
448
+ print(g.logLikelihood(arrow(tint, tlist(tint), tlist(tint)), p))
449
+ print()
450
+
451
+ print("countdown primitive")
452
+ p = Program.parse(
453
+ "(lambda (lambda (if (eq? $0 0) empty (cons (+ $0 1) ($1 (- $0 1))))))")
454
+ print(
455
+ g.logLikelihood(
456
+ arrow(
457
+ arrow(
458
+ tint, tlist(tint)), arrow(
459
+ tint, tlist(tint))), p))
460
+ print(_fix(9)(p.evaluate([])))
461
+ print("countdown w/ better primitives")
462
+ p = Program.parse(
463
+ "(lambda (lambda (if (eq0 $0) empty (cons (+1 $0) ($1 (-1 $0))))))")
464
+ print(
465
+ g.logLikelihood(
466
+ arrow(
467
+ arrow(
468
+ tint, tlist(tint)), arrow(
469
+ tint, tlist(tint))), p))
470
+
471
+ print()
472
+
473
+ print("prepend zeros")
474
+ p = Program.parse(
475
+ "(lambda (lambda (lambda (if (eq? $1 0) $0 (cons 0 ($2 (- $1 1) $0))))))")
476
+ print(
477
+ g.logLikelihood(
478
+ arrow(
479
+ arrow(
480
+ tint,
481
+ tlist(tint),
482
+ tlist(tint)),
483
+ tint,
484
+ tlist(tint),
485
+ tlist(tint)),
486
+ p))
487
+ print()
488
+ assert False
489
+
490
+ p = Program.parse(
491
+ "(lambda (fix1 $0 (lambda (lambda (if (empty? $0) 0 (+ 1 ($1 (cdr $0))))))))")
492
+ print(p.evaluate([])(list(range(17))))
493
+ print(g.logLikelihood(arrow(tlist(tbool), tint), p))
494
+
495
+ p = Program.parse(
496
+ "(lambda (lambda (if (empty? $0) 0 (+ 1 ($1 (cdr $0))))))")
497
+ print(
498
+ g.logLikelihood(
499
+ arrow(
500
+ arrow(
501
+ tlist(tbool), tint), arrow(
502
+ tlist(tbool), tint)), p))
503
+
504
+ p = Program.parse(
505
+ "(lambda (fix1 $0 (lambda (lambda (if (empty? $0) 0 (+ (car $0) ($1 (cdr $0))))))))")
506
+
507
+ print(p.evaluate([])(list(range(4))))
508
+ print(g.logLikelihood(arrow(tlist(tint), tint), p))
509
+
510
+ p = Program.parse(
511
+ "(lambda (lambda (if (empty? $0) 0 (+ (car $0) ($1 (cdr $0))))))")
512
+ print(p)
513
+ print(
514
+ g.logLikelihood(
515
+ arrow(
516
+ arrow(
517
+ tlist(tint),
518
+ tint),
519
+ tlist(tint),
520
+ tint),
521
+ p))
522
+
523
+ print("take")
524
+ p = Program.parse(
525
+ "(lambda (lambda (lambda (if (eq? $1 0) empty (cons (car $0) ($2 (- $1 1) (cdr $0)))))))")
526
+ print(p)
527
+ print(
528
+ g.logLikelihood(
529
+ arrow(
530
+ arrow(
531
+ tint,
532
+ tlist(tint),
533
+ tlist(tint)),
534
+ tint,
535
+ tlist(tint),
536
+ tlist(tint)),
537
+ p))
538
+ assert False
539
+
540
+ print(p.evaluate([])(list(range(4))))
541
+ print(g.logLikelihood(arrow(tlist(tint), tlist(tint)), p))
542
+
543
+ p = Program.parse(
544
+ """(lambda (fix (lambda (lambda (match $0 0 (lambda (lambda (+ $1 ($3 $0))))))) $0))""")
545
+ print(p.evaluate([])(list(range(4))))
546
+ print(g.logLikelihood(arrow(tlist(tint), tint), p))
dreamcoder/domains/list/main.py ADDED
@@ -0,0 +1,410 @@
1
+ import random
2
+ from collections import defaultdict
3
+ import json
4
+ import math
5
+ import os
6
+ import datetime
7
+
8
+ from dreamcoder.dreamcoder import explorationCompression
9
+ from dreamcoder.utilities import eprint, flatten, testTrainSplit
10
+ from dreamcoder.grammar import Grammar
11
+ from dreamcoder.task import Task
12
+ from dreamcoder.type import Context, arrow, tbool, tlist, tint, t0, UnificationFailure
13
+ from dreamcoder.domains.list.listPrimitives import basePrimitives, primitives, McCarthyPrimitives, bootstrapTarget_extra, no_length
14
+ from dreamcoder.domains.list.makeListTasks import make_list_bootstrap_tasks, sortBootstrap, EASYLISTTASKS
15
+
16
+
17
+ def retrieveJSONTasks(filename, features=False):
18
+ """
19
+ For JSON of the form:
20
+ {"name": str,
21
+ "type": {"input" : bool|int|list-of-bool|list-of-int,
22
+ "output": bool|int|list-of-bool|list-of-int},
23
+ "examples": [{"i": data, "o": data}]}
24
+ """
25
+ with open(filename, "r") as f:
26
+ loaded = json.load(f)
27
+ TP = {
28
+ "bool": tbool,
29
+ "int": tint,
30
+ "list-of-bool": tlist(tbool),
31
+ "list-of-int": tlist(tint),
32
+ }
33
+ return [Task(
34
+ item["name"],
35
+ arrow(TP[item["type"]["input"]], TP[item["type"]["output"]]),
36
+ [((ex["i"],), ex["o"]) for ex in item["examples"]],
37
+ features=(None if not features else list_features(
38
+ [((ex["i"],), ex["o"]) for ex in item["examples"]])),
39
+ cache=False,
40
+ ) for item in loaded]
41
+
42
+
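A hypothetical entry in this format (illustrative only, not taken from data/list_tasks.json):

    {"name": "double",
     "type": {"input": "list-of-int", "output": "list-of-int"},
     "examples": [{"i": [1, 2, 3], "o": [2, 4, 6]}]}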
43
+ def list_features(examples):
44
+ if any(isinstance(i, int) for (i,), _ in examples):
45
+ # obtain features for number inputs as list of numbers
46
+ examples = [(([i],), o) for (i,), o in examples]
47
+ elif any(not isinstance(i, list) for (i,), _ in examples):
48
+ # can't handle non-lists
49
+ return []
50
+ elif any(isinstance(x, list) for (xs,), _ in examples for x in xs):
51
+ # nested lists are hard to extract features for, so we'll
52
+ # obtain features as if flattened
53
+ examples = [(([x for xs in ys for x in xs],), o)
54
+ for (ys,), o in examples]
55
+
56
+ # assume all tasks have the same number of examples
57
+ # and all inputs are lists
58
+ features = []
59
+ ot = type(examples[0][1])
60
+
61
+ def mean(l): return 0 if not l else sum(l) / len(l)
62
+ imean = [mean(i) for (i,), o in examples]
63
+ ivar = [sum((v - imean[idx])**2
64
+ for v in examples[idx][0][0])
65
+ for idx in range(len(examples))]
66
+
67
+ # DISABLED length of each input and output
68
+ # total difference between length of input and output
69
+ # DISABLED normalized count of numbers in input but not in output
70
+ # total normalized count of numbers in input but not in output
71
+ # total difference between means of input and output
72
+ # total difference between variances of input and output
73
+ # output type (-1=bool, 0=int, 1=list)
74
+ # DISABLED outputs if integers, else -1s
75
+ # DISABLED outputs if bools (-1/1), else 0s
76
+ if ot == list: # lists of ints or bools
77
+ omean = [mean(o) for (i,), o in examples]
78
+ ovar = [sum((v - omean[idx])**2
79
+ for v in examples[idx][1])
80
+ for idx in range(len(examples))]
81
+
82
+ def cntr(
83
+ l, o): return 0 if not l else len(
84
+ set(l).difference(
85
+ set(o))) / len(l)
86
+ cnt_not_in_output = [cntr(i, o) for (i,), o in examples]
87
+
88
+ #features += [len(i) for (i,), o in examples]
89
+ #features += [len(o) for (i,), o in examples]
90
+ features.append(sum(len(i) - len(o) for (i,), o in examples))
91
+ #features += cnt_not_in_output
92
+ features.append(sum(cnt_not_in_output))
93
+ features.append(sum(om - im for im, om in zip(imean, omean)))
94
+ features.append(sum(ov - iv for iv, ov in zip(ivar, ovar)))
95
+ features.append(1)
96
+ # features += [-1 for _ in examples]
97
+ # features += [0 for _ in examples]
98
+ elif ot == bool:
99
+ outs = [o for (i,), o in examples]
100
+
101
+ #features += [len(i) for (i,), o in examples]
102
+ #features += [-1 for _ in examples]
103
+ features.append(sum(len(i) for (i,), o in examples))
104
+ #features += [0 for _ in examples]
105
+ features.append(0)
106
+ features.append(sum(imean))
107
+ features.append(sum(ivar))
108
+ features.append(-1)
109
+ # features += [-1 for _ in examples]
110
+ # features += [1 if o else -1 for o in outs]
111
+ else: # int
112
+ def cntr(
113
+ l, o): return 0 if not l else len(
114
+ set(l).difference(
115
+ set(o))) / len(l)
116
+ cnt_not_in_output = [cntr(i, [o]) for (i,), o in examples]
117
+ outs = [o for (i,), o in examples]
118
+
119
+ #features += [len(i) for (i,), o in examples]
120
+ #features += [1 for (i,), o in examples]
121
+ features.append(sum(len(i) for (i,), o in examples))
122
+ #features += cnt_not_in_output
123
+ features.append(sum(cnt_not_in_output))
124
+ features.append(sum(o - im for im, o in zip(imean, outs)))
125
+ features.append(sum(ivar))
126
+ features.append(0)
127
+ # features += outs
128
+ # features += [0 for _ in examples]
129
+
130
+ return features
131
+
132
+
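In every branch the function returns the same five aggregate statistics per task, e.g. for a list-to-list task:

    # [sum of input/output length differences, summed not-in-output ratio,
    #  summed difference of means, summed difference of variances,
    #  output-type code (1 = list, -1 = bool, 0 = int)]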
133
+ def isListFunction(tp):
134
+ try:
135
+ Context().unify(tp, arrow(tlist(tint), t0))
136
+ return True
137
+ except UnificationFailure:
138
+ return False
139
+
140
+
141
+ def isIntFunction(tp):
142
+ try:
143
+ Context().unify(tp, arrow(tint, t0))
144
+ return True
145
+ except UnificationFailure:
146
+ return False
147
+
148
+ try:
149
+ from dreamcoder.recognition import RecurrentFeatureExtractor
150
+ class LearnedFeatureExtractor(RecurrentFeatureExtractor):
151
+ H = 64
152
+
153
+ special = None
154
+
155
+ def tokenize(self, examples):
156
+ def sanitize(l): return [z if z in self.lexicon else "?"
157
+ for z_ in l
158
+ for z in (z_ if isinstance(z_, list) else [z_])]
159
+
160
+ tokenized = []
161
+ for xs, y in examples:
162
+ if isinstance(y, list):
163
+ y = ["LIST_START"] + y + ["LIST_END"]
164
+ else:
165
+ y = [y]
166
+ y = sanitize(y)
167
+ if len(y) > self.maximumLength:
168
+ return None
169
+
170
+ serializedInputs = []
171
+ for xi, x in enumerate(xs):
172
+ if isinstance(x, list):
173
+ x = ["LIST_START"] + x + ["LIST_END"]
174
+ else:
175
+ x = [x]
176
+ x = sanitize(x)
177
+ if len(x) > self.maximumLength:
178
+ return None
179
+ serializedInputs.append(x)
180
+
181
+ tokenized.append((tuple(serializedInputs), y))
182
+
183
+ return tokenized
184
+
185
+ def __init__(self, tasks, testingTasks=[], cuda=False):
186
+ self.lexicon = set(flatten((t.examples for t in tasks + testingTasks), abort=lambda x: isinstance(
187
+ x, str))).union({"LIST_START", "LIST_END", "?"})
188
+
189
+ # Calculate the maximum length
190
+ self.maximumLength = float('inf')  # must be infinite while tokenizing below, so tokenize() never rejects an example before the true maximum is measured
191
+ self.maximumLength = max(len(l)
192
+ for t in tasks + testingTasks
193
+ for xs, y in self.tokenize(t.examples)
194
+ for l in [y] + [x for x in xs])
195
+
196
+ self.recomputeTasks = True
197
+
198
+ super(
199
+ LearnedFeatureExtractor,
200
+ self).__init__(
201
+ lexicon=list(
202
+ self.lexicon),
203
+ tasks=tasks,
204
+ cuda=cuda,
205
+ H=self.H,
206
+ bidirectional=True)
207
+ except: pass
208
+
209
+ def train_necessary(t):
210
+ if t.name in {"head", "is-primes", "len", "pop", "repeat-many", "tail", "keep primes", "keep squares"}:
211
+ return True
212
+ if any(t.name.startswith(x) for x in {
213
+ "add-k", "append-k", "bool-identify-geq-k", "count-k", "drop-k",
214
+ "empty", "evens", "has-k", "index-k", "is-mod-k", "kth-largest",
215
+ "kth-smallest", "modulo-k", "mult-k", "remove-index-k",
216
+ "remove-mod-k", "repeat-k", "replace-all-with-index-k", "rotate-k",
217
+ "slice-k-n", "take-k",
218
+ }):
219
+ return "some"
220
+ return False
221
+
222
+
223
+ def list_options(parser):
224
+ parser.add_argument(
225
+ "--noMap", action="store_true", default=False,
226
+ help="Disable built-in map primitive")
227
+ parser.add_argument(
228
+ "--noUnfold", action="store_true", default=False,
229
+ help="Disable built-in unfold primitive")
230
+ parser.add_argument(
231
+ "--noLength", action="store_true", default=False,
232
+ help="Disable built-in length primitive")
233
+ parser.add_argument(
234
+ "--dataset",
235
+ type=str,
236
+ default="Lucas-old",
237
+ choices=[
238
+ "bootstrap",
239
+ "sorting",
240
+ "Lucas-old",
241
+ "Lucas-depth1",
242
+ "Lucas-depth2",
243
+ "Lucas-depth3"])
244
+ parser.add_argument("--maxTasks", type=int,
245
+ default=None,
246
+ help="truncate tasks to fit within this boundary")
247
+ parser.add_argument("--primitives",
248
+ default="common",
249
+ help="Which primitive set to use",
250
+ choices=["McCarthy", "base", "rich", "common", "noLength"])
251
+ parser.add_argument("--extractor", type=str,
252
+ choices=["hand", "deep", "learned"],
253
+ default="learned")
254
+ parser.add_argument("--split", metavar="TRAIN_RATIO",
255
+ type=float,
256
+ help="split test/train")
257
+ parser.add_argument("-H", "--hidden", type=int,
258
+ default=64,
259
+ help="number of hidden units")
260
+ parser.add_argument("--random-seed", type=int, default=17)
261
+
262
+
263
+ def main(args):
264
+ """
265
+ Takes the return value of the `commandlineArguments()` function as input and
266
+ trains/tests the model on manipulating sequences of numbers.
267
+ """
268
+ random.seed(args.pop("random_seed"))
269
+
270
+ dataset = args.pop("dataset")
271
+ tasks = {
272
+ "Lucas-old": lambda: retrieveJSONTasks("data/list_tasks.json") + sortBootstrap(),
273
+ "bootstrap": make_list_bootstrap_tasks,
274
+ "sorting": sortBootstrap,
275
+ "Lucas-depth1": lambda: retrieveJSONTasks("data/list_tasks2.json")[:105],
276
+ "Lucas-depth2": lambda: retrieveJSONTasks("data/list_tasks2.json")[:4928],
277
+ "Lucas-depth3": lambda: retrieveJSONTasks("data/list_tasks2.json"),
278
+ }[dataset]()
279
+
280
+ maxTasks = args.pop("maxTasks")
281
+ if maxTasks and len(tasks) > maxTasks:
282
+ necessaryTasks = [] # maxTasks will not consider these
283
+ if dataset.startswith("Lucas2.0") and dataset != "Lucas2.0-depth1":
284
+ necessaryTasks = tasks[:105]
285
+
286
+ eprint("Unwilling to handle {} tasks, truncating..".format(len(tasks)))
287
+ random.shuffle(tasks)
288
+ del tasks[maxTasks:]
289
+ tasks = necessaryTasks + tasks
290
+
291
+ if dataset.startswith("Lucas"):
292
+ # extra tasks for filter
293
+ tasks.extend([
294
+ Task("remove empty lists",
295
+ arrow(tlist(tlist(tbool)), tlist(tlist(tbool))),
296
+ [((ls,), list(filter(lambda l: len(l) > 0, ls)))
297
+ for _ in range(15)
298
+ for ls in [[[random.random() < 0.5 for _ in range(random.randint(0, 3))]
299
+ for _ in range(4)]]]),
300
+ Task("keep squares",
301
+ arrow(tlist(tint), tlist(tint)),
302
+ [((xs,), list(filter(lambda x: int(math.sqrt(x)) ** 2 == x,
303
+ xs)))
304
+ for _ in range(15)
305
+ for xs in [[random.choice([0, 1, 4, 9, 16, 25])
306
+ if random.random() < 0.5
307
+ else random.randint(0, 9)
308
+ for _ in range(7)]]]),
309
+ Task("keep primes",
310
+ arrow(tlist(tint), tlist(tint)),
311
+ [((xs,), list(filter(lambda x: x in {2, 3, 5, 7, 11, 13, 17,
312
+ 19, 23, 29, 31, 37}, xs)))
313
+ for _ in range(15)
314
+ for xs in [[random.choice([2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37])
315
+ if random.random() < 0.5
316
+ else random.randint(0, 9)
317
+ for _ in range(7)]]]),
318
+ ])
319
+ for i in range(4):
320
+ tasks.extend([
321
+ Task("keep eq %s" % i,
322
+ arrow(tlist(tint), tlist(tint)),
323
+ [((xs,), list(filter(lambda x: x == i, xs)))
324
+ for _ in range(15)
325
+ for xs in [[random.randint(0, 6) for _ in range(5)]]]),
326
+ Task("remove eq %s" % i,
327
+ arrow(tlist(tint), tlist(tint)),
328
+ [((xs,), list(filter(lambda x: x != i, xs)))
329
+ for _ in range(15)
330
+ for xs in [[random.randint(0, 6) for _ in range(5)]]]),
331
+ Task("keep gt %s" % i,
332
+ arrow(tlist(tint), tlist(tint)),
333
+ [((xs,), list(filter(lambda x: x > i, xs)))
334
+ for _ in range(15)
335
+ for xs in [[random.randint(0, 6) for _ in range(5)]]]),
336
+ Task("remove gt %s" % i,
337
+ arrow(tlist(tint), tlist(tint)),
338
+ [((xs,), list(filter(lambda x: not x > i, xs)))
339
+ for _ in range(15)
340
+ for xs in [[random.randint(0, 6) for _ in range(5)]]])
341
+ ])
342
+
343
+ def isIdentityTask(t):
344
+ return all( len(xs) == 1 and xs[0] == y for xs, y in t.examples )
345
+ eprint("Removed", sum(isIdentityTask(t) for t in tasks), "tasks that were just the identity function")
346
+ tasks = [t for t in tasks if not isIdentityTask(t) ]
347
+
348
+ prims = {"base": basePrimitives,
349
+ "McCarthy": McCarthyPrimitives,
350
+ "common": bootstrapTarget_extra,
351
+ "noLength": no_length,
352
+ "rich": primitives}[args.pop("primitives")]()
353
+ haveLength = not args.pop("noLength")
354
+ haveMap = not args.pop("noMap")
355
+ haveUnfold = not args.pop("noUnfold")
356
+ eprint(f"Including map as a primitive? {haveMap}")
357
+ eprint(f"Including length as a primitive? {haveLength}")
358
+ eprint(f"Including unfold as a primitive? {haveUnfold}")
359
+ baseGrammar = Grammar.uniform([p
360
+ for p in prims
361
+ if (p.name != "map" or haveMap) and \
362
+ (p.name != "unfold" or haveUnfold) and \
363
+ (p.name != "length" or haveLength)])
364
+
365
+ extractor = {
366
+ "learned": LearnedFeatureExtractor,
367
+ }[args.pop("extractor")]
368
+ extractor.H = args.pop("hidden")
369
+
370
+ timestamp = datetime.datetime.now().isoformat()
371
+ outputDirectory = "experimentOutputs/list/%s"%timestamp
372
+ os.system("mkdir -p %s"%outputDirectory)
373
+
374
+ args.update({
375
+ "featureExtractor": extractor,
376
+ "outputPrefix": "%s/list"%outputDirectory,
377
+ "evaluationTimeout": 0.0005,
378
+ })
379
+
380
+
381
+ eprint("Got {} list tasks".format(len(tasks)))
382
+ split = args.pop("split")
383
+ if split:
384
+ train_some = defaultdict(list)
385
+ for t in tasks:
386
+ necessary = train_necessary(t)
387
+ if not necessary:
388
+ continue
389
+ if necessary == "some":
390
+ train_some[t.name.split()[0]].append(t)
391
+ else:
392
+ t.mustTrain = True
393
+ for k in sorted(train_some):
394
+ ts = train_some[k]
395
+ random.shuffle(ts)
396
+ ts.pop().mustTrain = True
397
+
398
+ test, train = testTrainSplit(tasks, split)
399
+ if True:
400
+ test = [t for t in test
401
+ if t.name not in EASYLISTTASKS]
402
+
403
+ eprint(
404
+ "Alotted {} tasks for training and {} for testing".format(
405
+ len(train), len(test)))
406
+ else:
407
+ train = tasks
408
+ test = []
409
+
410
+ explorationCompression(baseGrammar, train, testingTasks=test, **args)
dreamcoder/domains/list/makeListTasks.py ADDED
@@ -0,0 +1,587 @@
1
+
2
+
3
+ from dreamcoder.type import *
4
+ from dreamcoder.task import Task
5
+ from dreamcoder.utilities import eprint, hashable
6
+
7
+ from random import randint, random, seed
8
+ from itertools import product
9
+
10
+ # Excluded routines either impossible or astronomically improbable
11
+ # I'm cutting these off at ~20 nats in learned grammars.
12
+ EXCLUDES = {
13
+ "dedup",
14
+ "intersperse-k",
15
+ "pow-base-k",
16
+ "prime",
17
+ "replace-all-k-with-n",
18
+ "replace-index-k-with-n",
19
+ "uniq",
20
+ }
21
+
22
+ # These are tasks that are easy (solved from base DSL) and also uninteresting
23
+ # We exclude these from the test set
24
+ EASYLISTTASKS = {
25
+ "add-k with k=2",
26
+ "bool-identify-geq-k with k=2",
27
+ "bool-identify-geq-k with k=3",
28
+ "bool-identify-is-mod-k with k=1",
29
+ "bool-identify-is-prime",
30
+ "bool-identify-k with k=0",
31
+ "bool-identify-k with k=1",
32
+ "bool-identify-k with k=2",
33
+ "caesar-cipher-k-modulo-n with k=3 and n=2",
34
+ "drop-k with k=1",
35
+ "drop-k with k=2",
36
+ "drop-k with k=4",
37
+ "index-head",
38
+ "index-k with k=2",
39
+ "index-k with k=4",
40
+ "is-mod-k with k=1",
41
+ "is-odds",
42
+ "is-squares",
43
+ "pow-k with k=2",
44
+ "pow-k with k=3",
45
+ "prepend-index-k with k=3",
46
+ "prepend-index-k with k=5",
47
+ "prepend-k with k=1",
48
+ "prepend-k with k=2",
49
+ "prepend-k with k=3",
50
+ "remove-index-k with k=1",
51
+ "replace-all-with-index-k with k=2",
52
+ "replace-all-with-index-k with k=3",
53
+ "slice-k-n with k=1 and n=2",
54
+ "slice-k-n with k=2 and n=1",
55
+ "slice-k-n with k=3 and n=1",
56
+ }
57
+
58
+ def make_list_task(name, examples, **params):
59
+ input_type = guess_type([i for (i,), _ in examples])
60
+ output_type = guess_type([o for _, o in examples])
61
+
62
+ # We can internally handle lists of bools.
63
+ # We explicitly create these by modifying existing routines.
64
+ if name.startswith("identify"):
65
+ boolexamples = [((i,), list(map(bool, o))) for (i,), o in examples]
66
+ yield from make_list_task("bool-" + name, boolexamples, **params)
67
+ # for now, we'll stick with the boolean-only tasks and not have a copy
68
+ # for integers.
69
+ return
70
+
71
+ program_type = arrow(input_type, output_type)
72
+ cache = all(hashable(x) for x in examples)
73
+
74
+ if params:
75
+ eq_params = ["{}={}".format(k, v) for k, v in params.items()]
76
+ if len(eq_params) == 1:
77
+ ext = eq_params[0]
78
+ elif len(eq_params) == 2:
79
+ ext = "{} and {}".format(*eq_params)
80
+ else:
81
+ ext = ", ".join(eq_params[:-1])
82
+ ext = "{}, and {}".format(ext, eq_params[-1])
83
+ name += " with " + ext
84
+
85
+ yield Task(name, program_type, examples, cache=cache)
86
+
87
+
88
+ def make_list_tasks(n_examples):
89
+ import listroutines as lr
90
+
91
+ for routine in lr.find(count=100): # all routines
92
+ if routine.id in EXCLUDES:
93
+ continue
94
+ if routine.is_parametric():
95
+ keys = list(routine.example_params()[0].keys())
96
+ for params in map(lambda values: dict(zip(keys, values)),
97
+ product(range(6), repeat=len(keys))):
98
+ try:
99
+ if routine.id == "rotate-k":
100
+ # rotate-k is hard if list is smaller than k
101
+ k = params["k"]
102
+ if k < 1:
103
+ continue
104
+ inps = []
105
+ for _ in range(n_examples):
106
+ r = randint(abs(k) + 1, 17)
107
+ inp = routine.gen(len=r, **params)[0]
108
+ inps.append(inp)
109
+ else:
110
+ inps = routine.gen(count=n_examples, **params)
111
+ examples = [((inp,), routine.eval(inp, **params))
112
+ for inp in inps]
113
+ yield from make_list_task(routine.id, examples, **params)
114
+ except lr.APIError: # invalid params
115
+ continue
116
+ else:
117
+ inps = routine.examples()
118
+ if len(inps) > n_examples:
119
+ inps = inps[:n_examples]
120
+ elif len(inps) < n_examples:
121
+ inps += routine.gen(count=(n_examples - len(inps)))
122
+ examples = [((inp,), routine.eval(inp)) for inp in inps]
123
+ yield from make_list_task(routine.id, examples)
124
+
125
+
126
+ def make_list_bootstrap_tasks():
127
+ seed(42)
128
+
129
+ def suffixes(l):
130
+ if l == []:
131
+ return []
132
+ else:
133
+ return [l[1:]] + suffixes(l[1:])
134
+
135
+ def flip(): return random() > 0.5
136
+
137
+ def randomSuffix():
138
+ return [randint(0, 9) for _ in range(randint(1, 4))]
139
+
140
+ def randomList(minimum=0, minimumLength=4, maximumLength=6):
141
+ return [randint(minimum, 9) for _ in range(randint(minimumLength, maximumLength))]
142
+
143
+ def randomListOfLists():
144
+ return [randomSuffix() for _ in range(randint(2, 4))]
145
+
146
+ def randomListOfLists_bool(l=None):
147
+ if l is None:
148
+ l = randint(4, 7)
149
+ return [randomBooleanList() for _ in range(l)]
150
+
151
+ def randomBooleanList():
152
+ return [flip() for _ in range(randint(4, 7))]
153
+
154
+ # Reliably learned in under a minute; always triggers learning of length
155
+ # primitive
156
+ lengthBootstrap = [
157
+ # Task("length bool", arrow(tlist(tbool), tint),
158
+ # [((l,), len(l))
159
+ # for _ in range(10)
160
+ # for l in [[flip() for _ in range(randint(0, 10))]]]),
161
+ Task("length int", arrow(tlist(tint), tint),
162
+ [((l,), len(l))
163
+ for _ in range(10)
164
+ for l in [randomList()]]),
165
+ Task("map length", arrow(tlist(tlist(tint)), tlist(tint)),
166
+ [((xss,), [len(xs) for xs in xss])
167
+ for _ in range(10)
168
+ for xss in [randomListOfLists()] ])
169
+ ]
170
+
171
+ # Encourages learning of unfolding
172
+ unfoldBootstrap = [
173
+ Task("countdown", arrow(tint, tlist(tint)),
174
+ [((n,), list(range(n + 1, 1, -1)))
175
+ for n in range(10)]),
176
+ Task("weird count", arrow(tint, tlist(tint)),
177
+ [((n,), list(range(-n,0,-1)))
178
+ for n in range(-10,0) ]),
179
+ Task("take every other", arrow(tlist(tint),tlist(tint)),
180
+ [((l,), [x for j,x in enumerate(l) if j%2 == 0])
181
+ for _ in range(9)
182
+ for l in [ [randint(0, 9) for _ in range(randint(1,4)*2)] ] ] + [(([],),[])]),
183
+ # Task("stutter every other", arrow(tlist(tint),tlist(tint)),
184
+ # [((l,), [l[int(j/2)] for j in range(len(l)) ])
185
+ # for _ in range(10)
186
+ # for l in [ [randint(0, 9) for _ in range(randint(1,4)*2)] ] ]),
187
+ # Task("take until 3 reached", arrow(tlist(tint),tlist(tint)),
188
+ # [((p + [3] + s,),p)
189
+ # for _ in range(10)
190
+ # for p in [ [z for z in randomList()[:5] if z != 3 ]]
191
+ # for s in [randomList()] ]),
192
+ Task("drop last element", arrow(tlist(tint),tlist(tint)),
193
+ [((l,), l[:-1])
194
+ for _ in range(10)
195
+ for l in [ [randint(0, 9) for _ in range(randint(2,5))] ] ]),
196
+ # Task("suffixes", arrow(tlist(tint), tlist(tlist(tint))),
197
+ # [((l,), suffixes(l))
198
+ # for _ in range(10)
199
+ # for l in [randomList()]]),
200
+ Task("range", arrow(tint, tlist(tint)),
201
+ [((n,), list(range(n)))
202
+ for n in range(10)]),
203
+ Task("range inclusive", arrow(tint, tlist(tint)),
204
+ [((n,), list(range(n + 1)))
205
+ for n in range(10)]),
206
+ # Task("range inclusive+1", arrow(tint, tlist(tint)),
207
+ # [((n,), list(range(n + 2)))
208
+ # for n in range(10)]),
209
+ # Task("range exclusive", arrow(tint, tlist(tint)),
210
+ # [((n,), list(range(n - 1)))
211
+ # for n in range(2, 11)]),
212
+ # Task("range length", arrow(tlist(tint),tlist(tint)),
213
+ # [((l,),list(range(len(l))))
214
+ # for _ in range(10)
215
+ # for l in [randomList()] ])
216
+ ]
217
+
218
+ # Encourages learning how to treat a list as an array
219
+ arrayBootstrap = [
220
+ Task("index int", arrow(tint, tlist(tint), tint),
221
+ [((n, l), l[n])
222
+ for n in range(10)
223
+ for l in [[randint(0, 9) for _ in range(randint(n + 1, n + 5))]]]),
224
+ # Task("last n", arrow(tint, tlist(tint), tlist(tint)),
225
+ # [((n, l), l[-n:])
226
+ # for n in range(10)
227
+ # for l in [[randint(0, 9) for _ in range(randint(n + 1, n + 5))]]]),
228
+ Task("1-index int", arrow(tint, tlist(tint), tint),
229
+ [((n, l), l[n - 1])
230
+ for n in range(1,11)
231
+ for l in [[randint(0, 9) for _ in range(randint(n + 1, n + 4))]]])
232
+
233
+ # Task("index bool", arrow(tint, tlist(tbool), tbool),
234
+ # [((n, l), l[n])
235
+ # for n in range(10)
236
+ # for l in [[flip() for _ in range(randint(n + 1, n + 5))]]])
237
+ ]
238
+
239
+ # Teaches how to slice lists, not sure if we really need this though
240
+ sliceBootstrap = [
241
+ Task("take bool", arrow(tint, tlist(tbool), tlist(tbool)),
242
+ [((n, l), l[:n])
243
+ for n in range(10)
244
+ for l in [[flip() for _ in range(randint(n, n + 5))]]]),
245
+ Task("drop bool", arrow(tint, tlist(tbool), tlist(tbool)),
246
+ [((n, l), l[n:])
247
+ for n in range(10)
248
+ for l in [[flip() for _ in range(randint(n, n + 5))]]]),
249
+
250
+ Task("take int", arrow(tint, tlist(tint), tlist(tint)),
251
+ [((n, l), l[:n])
252
+ for n in range(10)
253
+ for l in [[randint(0, 9) for _ in range(randint(n, n + 5))]]]),
254
+ Task("drop int", arrow(tint, tlist(tint), tlist(tint)),
255
+ [((n, l), l[n:])
256
+ for n in range(10)
257
+ for l in [[randint(0, 9) for _ in range(randint(n, n + 5))]]]),
258
+
259
+ ]
260
+
261
+ # learning to fold
262
+ foldBootstrap = [
263
+ Task("stutter", arrow(tlist(tint),tlist(tint)),
264
+ [((l,), [z for x in l for z in [x,x] ])
265
+ for _ in range(10)
266
+ for l in [randomList()] ]),
267
+ Task("sum", arrow(tlist(tint), tint),
268
+ [((l,), sum(l))
269
+ for _ in range(10)
270
+ for l in [randomList()]]),
271
+ # Task("difference", arrow(tlist(tint), tint),
272
+ # [((l,), reduce(lambda x, y: y - x, reversed(l), 1))
273
+ # for _ in range(10)
274
+ # for l in [randomList()[:4]]]),
275
+ # Task("append bool", arrow(tlist(tbool), tlist(tbool), tlist(tbool)),
276
+ # [((x, y), x + y)
277
+ # for _ in range(10)
278
+ # for [x, y] in [[randomBooleanList(), randomBooleanList()]]]),
279
+ Task("append constant 0", arrow(tlist(tint),tlist(tint)),
280
+ [((l,),l + [0])
281
+ for _ in range(10)
282
+ for l in [randomList()] ]),
283
+ ]
284
+
285
+ # learning to map
286
+ mapBootstrap = [
287
+ Task("map double", arrow(tlist(tint), tlist(tint)),
288
+ [((l,), list(map(lambda n: n * 2, l)))
289
+ for _ in range(10)
290
+ for l in [randomList()]]),
291
+ Task("map increment", arrow(tlist(tint),tlist(tint)),
292
+ [((l,),list(map(lambda n: n+1, l)))
293
+ for _ in range(10)
294
+ for l in [randomList()] ]),
295
+ Task("map negation", arrow(tlist(tint),tlist(tint)),
296
+ [((l,),list(map(lambda n: 0-n, l)))
297
+ for _ in range(10)
298
+ for l in [randomList()] ]),
299
+ # Task("map car", arrow(tlist(tlist(tint)), tlist(tint)),
300
+ # [((l,), [n[0] for n in l])
301
+ # for _ in range(10)
302
+ # for l in [randomListOfLists()]]),
303
+ # Task("map cdr", arrow(tlist(tlist(tbool)),tlist(tlist(tbool))),
304
+ # [((l,),map(lambda n: n[1:],l))
305
+ # for _ in range(10)
306
+ # for l in [randomListOfLists_bool()]]),
307
+ # Task("map empty?", arrow(tlist(tlist(tint)), tlist(tboolean)),
308
+ # [((l,), [n == [] for n in l])
309
+ # for _ in range(10)
310
+ # for l in [[[] if flip() else randomList() for _ in range(randint(1, 5))]]]),
311
+
312
+ # Task("map eq 0?", arrow(tlist(tint),tlist(tboolean)),
313
+ # [((l,),map(lambda n: 0 == n,l))
314
+ # for _ in range(10)
315
+ # for l in [[ randint(0,3) for _ in range(randint(4,7)) ]] ])
316
+
317
+ ]
318
+ difficultMaps = [
319
+ Task("map quadruple", arrow(tlist(tint), tlist(tint)),
320
+ [((l,), list(map(lambda n: n * 4, l)))
321
+ for _ in range(10)
322
+ for l in [randomList()]]),
323
+ Task("map add 3", arrow(tlist(tint),tlist(tint)),
324
+ [((l,),list(map(lambda n: n+3, l)))
325
+ for _ in range(10)
326
+ for l in [randomList()] ]),
327
+
328
+ ]
329
+
330
+ # Learning to zip lists together
331
+ zipBootstrap = [
332
+ Task("zip plus", arrow(tlist(tint),tlist(tint),tlist(tint)),
333
+ [((l1,l2),list(map(lambda x,y: x+y,l1,l2)))
334
+ for _ in range(10)
335
+ for l1 in [randomList(minimumLength=2, maximumLength=4)]
336
+ for l2 in [[ randint(0,9) for _ in range(len(l1)) ]]]),
337
+ Task("zip minus", arrow(tlist(tint),tlist(tint),tlist(tint)),
338
+ [((l1,l2),list(map(lambda x,y: x-y,l1,l2)))
339
+ for _ in range(10)
340
+ for l1 in [randomList(minimumLength=2, maximumLength=4)]
341
+ for l2 in [[ randint(0,9) for _ in range(len(l1)) ]]]),
342
+ # Task("zip eq?", arrow(tlist(tint), tlist(tint), tlist(tbool)),
343
+ # [((l1, l2), list(map(lambda x, y: x == y, l1, l2)))
344
+ # for _ in range(10)
345
+ # for l1 in [[randint(0, 3) for _ in range(randint(4, 7))]]
346
+ # for l2 in [[randint(0, 3) for _ in range(len(l1))]]]),
347
+ # Task("zip cons", arrow(tlist(tbool), tlist(tlist(tbool)), tlist(tlist(tbool))),
348
+ # [((l1, l2), list(map(lambda x, y: [x] + y, l1, l2)))
349
+ # for _ in range(10)
350
+ # for l1 in [randomBooleanList()]
351
+ # for l2 in [randomListOfLists_bool(l=len(l1))]]),
352
+ # Task("zip cons", arrow(tlist(tint),tlist(tlist(tint)),tlist(tlist(tint))),
353
+ # [((l1,l2),list(map(lambda x,y: [x]+y,l1,l2)))
354
+ # for _ in range(10)
355
+ # for l1 in [randomList()]
356
+ # for l2 in [[ randomList() for _ in range(len(l1)) ]]]),
357
+ ]
358
+
359
+ # Learning to filter
360
+ filterBootstrap = [
361
+ # Task("remove empty lists",
362
+ # arrow(tlist(tlist(tbool)), tlist(tlist(tbool))),
363
+ # [((ls,), [l for l in ls if len(l) > 0])
364
+ # for _ in range(10)
365
+ # for ls in [[[flip() for _ in range(randint(0, 3))]
366
+ # for _ in range(4)]]])
367
+ # Task("remove non 0s",
368
+ # arrow(tlist(tint), tlist(tint)),
369
+ # [((xs,), filter(lambda x: x == 0, xs))
370
+ # for _ in range(10)
371
+ # for xs in [[ randint(0,3) for _ in range(5) ]] ]),
372
+ Task("remove 0s",
373
+ arrow(tlist(tint), tlist(tint)),
374
+ [((xs,), [x for x in xs if x != 0])
375
+ for _ in range(10)
376
+ for xs in [[randint(0, 3) for _ in range(5)]]]),
377
+ Task("remove non-positives",
378
+ arrow(tlist(tint), tlist(tint)),
379
+ [((xs,), [x for x in xs if not (x > 1)])
380
+ for _ in range(10)
381
+ for xs in [[randint(0, 3) for _ in range(5)]]]),
382
+ ]
383
+
384
+ return lengthBootstrap + filterBootstrap + \
385
+ unfoldBootstrap + arrayBootstrap + foldBootstrap + mapBootstrap + zipBootstrap
386
+
387
+
388
+ def bonusListProblems():
389
+ # Taken from https://www.ijcai.org/Proceedings/75/Papers/037.pdf
390
+ # These problems might be a lot easier if we do not use numbers
391
+ def randomList(lb=None, ub=None):
392
+ if lb is None:
393
+ lb = 2
394
+ if ub is None:
395
+ ub = 5
396
+ return [randint(0, 5) for _ in range(randint(lb, ub))]
397
+
398
+ bonus = [
399
+ Task(
400
+ "pair reverse", arrow(tlist(tint), tlist(tint)),
401
+ [((x,), [x[j + (1 if j % 2 == 0 else -1)]
402
+ for j in range(len(x))])
403
+ for _ in range(5)
404
+ for x in [randomList(10, 10)]]
405
+ ),
406
+ Task(
407
+ "duplicate each element", arrow(tlist(tint), tlist(tint)),
408
+ [((x,), [a for z in x for a in [z] * 2])
409
+ for _ in range(5)
410
+ for x in [randomList(4, 6)]]
411
+ ),
412
+ Task(
413
+ "reverse duplicate each element", arrow(tlist(tint), tlist(tint)),
414
+ [((x,), [a for z in reversed(x) for a in [z] * 2]) for _ in range(5) for x in [randomList(4, 6)]]
415
+ ),
416
+ ]
417
+ return bonus
418
+
419
+ def sortBootstrap():
420
+ # These tasks have as their goal the learning of (1) filter, and
421
+ # (2) sort, which uses filter.
422
+ def flip(): return random() > 0.5
423
+ def randomList(lb=None, ub=None):
424
+ if lb is None:
425
+ lb = 2
426
+ if ub is None:
427
+ ub = 5
428
+ return [randint(0, 10) for _ in range(randint(lb, ub))]
429
+ def randomBooleanList():
430
+ return [flip() for _ in range(randint(4, 7))]
431
+ def removeDuplicates(l):
432
+ if len(l) == 0: return l
433
+ return [l[0]] + removeDuplicates([ z for z in l if z != l[0] ])
434
+
435
+ filterBootstrap = [
436
+ # Task("remove empty lists",
437
+ # arrow(tlist(tlist(tbool)), tlist(tlist(tbool))),
438
+ # [((ls,), [l for l in ls if len(l) > 0])
439
+ # for _ in range(10)
440
+ # for ls in [[[flip() for _ in range(randint(0, 3))]
441
+ # for _ in range(4)]]]),
442
+ # Task("remove non 0s",
443
+ # arrow(tlist(tint), tlist(tint)),
444
+ # [((xs,), filter(lambda x: x == 0, xs))
445
+ # for _ in range(10)
446
+ # for xs in [[ randint(0,3) for _ in range(5) ]] ]),
447
+ Task("remove 0s",
448
+ arrow(tlist(tint), tlist(tint)),
449
+ [((xs,), [x for x in xs if x != 0])
450
+ for _ in range(10)
451
+ for xs in [[randint(0, 3) for _ in range(5)]]]),
452
+ # Task("remove primes",
453
+ # arrow(tlist(tint), tlist(tint)),
454
+ # [((xs,), [x for x in xs if not (x in {2,3,5,7,11,13,17,19,23})])
455
+ # for _ in range(10)
456
+ # for xs in [[randint(0, 20) for _ in range(7)]]]),
457
+ Task("remove squares",
458
+ arrow(tlist(tint), tlist(tint)),
459
+ [((xs,), [x for x in xs if not (int(x**0.5)**2 == x)])
460
+ for _ in range(10)
461
+ for xs in [[randint(0, 20) for _ in range(7)]]]),
462
+ Task("remove > 1",
463
+ arrow(tlist(tint), tlist(tint)),
464
+ [((xs,), [x for x in xs if not (x > 1)])
465
+ for _ in range(10)
466
+ for xs in [[randint(0, 5) for _ in range(7)]]]),
467
+ ]
468
+
469
+ # Needed for selection sort
470
+ minimumBootstrap = [
471
+ Task("min2", arrow(tint,tint,tint),
472
+ [((x,y),min(x,y))
473
+ for x in range(4)
474
+ for y in range(4) ]),
475
+ Task("minimum of list", arrow(tlist(tint),tint),
476
+ [((l,),min(l))
477
+ for _ in range(15)
478
+ for l in [randomList()] ])
479
+ ]
480
+
481
+ appendBootstrap = [
482
+ Task("append bool", arrow(tlist(tbool), tlist(tbool), tlist(tbool)),
483
+ [((x, y), x + y)
484
+ for _ in range(10)
485
+ for [x, y] in [[randomBooleanList(), randomBooleanList()]]]),
486
+ Task("append int", arrow(tlist(tint), tlist(tint), tlist(tint)),
487
+ [((x, y), x + y)
488
+ for _ in range(10)
489
+ for [x, y] in [[randomList(), randomList()]]])
490
+ ]
491
+
492
+ insertionBootstrap = [
493
+ Task("filter greater than or equal", arrow(tint,tlist(tint),tlist(tint)),
494
+ [((x,l), [y for y in l if y >= x ])
495
+ for _ in range(15)
496
+ for x in [randint(0,5)]
497
+ for l in [randomList()] ]),
498
+ Task("filter less than", arrow(tint,tlist(tint),tlist(tint)),
499
+ [((x,l), [y for y in l if y < x ])
500
+ for _ in range(15)
501
+ for x in [randint(0,5)]
502
+ for l in [randomList()] ]),
503
+ Task("insert into sorted list (I)", arrow(tint,tlist(tint),tlist(tint)),
504
+ [((x,l), [y for y in l if y < x ] + [x] + [y for y in l if y >= x ])
505
+ for _ in range(15)
506
+ for x in [randint(0,5)]
507
+ for _l in [randomList()]
508
+ for l in [sorted(_l)] ]),
509
+ Task("insert into sorted list (II)", arrow(tint,tlist(tint),tlist(tint)),
510
+ [((x,l), [y for y in l if y < x ] + [x] + [y for y in l if y >= x ])
511
+ for _ in range(15)
512
+ for x in [randint(0,5)]
513
+ for l in [randomList()] ])
514
+ ]
515
+
516
+
517
+ sortTask = [
518
+ Task("sort-and-deduplicate", arrow(tlist(tint),tlist(tint)),
519
+ [((l,),list(sorted(l)))
520
+ for _ in range(15)
521
+ for l in [removeDuplicates(randomList())]
522
+ ])]
523
+
524
+ slowSort = [
525
+ Task("+1 maximum list", arrow(tlist(tint), tint),
526
+ [((l,),max(l) + 1)
527
+ for _ in range(15)
528
+ for l in [randomList()] ]),
529
+ Task("range +1 maximum list", arrow(tlist(tint), tlist(tint)),
530
+ [((l,),list(range(max(l) + 1)))
531
+ for _ in range(15)
532
+ for l in [randomList()] ]),
533
+ ]
534
+
535
+
536
+ tasks = sortTask + slowSort
537
+ for t in tasks: t.mustTrain = True
538
+ return tasks
539
+
540
+
541
+ def exportTasks():
542
+ import sys
543
+ import pickle as pickle
544
+
545
+ n_examples = 15
546
+ if len(sys.argv) > 1:
547
+ n_examples = int(sys.argv[1])
548
+
549
+ eprint("Downloading and generating dataset")
550
+ tasks = sorted(make_list_tasks(n_examples), key=lambda t: t.name)
551
+ eprint("Got {} list tasks".format(len(tasks)))
552
+
553
+ with open("data/list_tasks.pkl", "w") as f:
554
+ pickle.dump(tasks, f)
555
+ eprint("Wrote list tasks to data/list_tasks.pkl")
556
+
557
+
558
+ if __name__ == "__main__":
559
+ import json
560
+ def retrieveJSONTasks(filename, features=False):
561
+ """
562
+ For JSON of the form:
563
+ {"name": str,
564
+ "type": {"input" : bool|int|list-of-bool|list-of-int,
565
+ "output": bool|int|list-of-bool|list-of-int},
566
+ "examples": [{"i": data, "o": data}]}
567
+ """
568
+ with open(filename, "r") as f:
569
+ loaded = json.load(f)
570
+ TP = {
571
+ "bool": tbool,
572
+ "int": tint,
573
+ "list-of-bool": tlist(tbool),
574
+ "list-of-int": tlist(tint),
575
+ }
576
+ return [Task(
577
+ item["name"],
578
+ arrow(TP[item["type"]["input"]], TP[item["type"]["output"]]),
579
+ [((ex["i"],), ex["o"]) for ex in item["examples"]],
580
+ features=(None if not features else list_features(
581
+ [((ex["i"],), ex["o"]) for ex in item["examples"]])),
582
+ cache=False,
583
+ ) for item in loaded]
584
+ for t in retrieveJSONTasks("data/list_tasks.json") + sortBootstrap() + make_list_bootstrap_tasks():
585
+ print(t.describe())
586
+ print()
587
+ # exportTasks()
dreamcoder/domains/logo/__init__.py ADDED
File without changes
dreamcoder/domains/logo/logoPrimitives.py ADDED
@@ -0,0 +1,41 @@
1
+ from dreamcoder.program import Primitive, Program
2
+ from dreamcoder.type import arrow, baseType, tint
3
+
4
+ turtle = baseType("turtle")
5
+ tstate = baseType("tstate")
6
+ tangle = baseType("tangle")
7
+ tlength = baseType("tlength")
8
+
9
+ primitives = [
10
+ Primitive("logo_UA", tangle, ""),
11
+ Primitive("logo_UL", tlength, ""),
12
+
13
+ Primitive("logo_ZA", tangle, ""),
14
+ Primitive("logo_ZL", tlength, ""),
15
+
16
+ Primitive("logo_DIVA", arrow(tangle,tint,tangle), ""),
17
+ Primitive("logo_MULA", arrow(tangle,tint,tangle), ""),
18
+ Primitive("logo_DIVL", arrow(tlength,tint,tlength), ""),
19
+ Primitive("logo_MULL", arrow(tlength,tint,tlength), ""),
20
+
21
+ Primitive("logo_ADDA", arrow(tangle,tangle,tangle), ""),
22
+ Primitive("logo_SUBA", arrow(tangle,tangle,tangle), ""),
23
+ # Primitive("logo_ADDL", arrow(tlength,tlength,tlength), ""),
24
+ # Primitive("logo_SUBL", arrow(tlength,tlength,tlength), ""),
25
+
26
+ # Primitive("logo_PU", arrow(turtle,turtle), ""),
27
+ # Primitive("logo_PD", arrow(turtle,turtle), ""),
28
+ Primitive("logo_PT", arrow(arrow(turtle,turtle),arrow(turtle,turtle)), None),
29
+ Primitive("logo_FWRT", arrow(tlength,tangle,turtle,turtle), ""),
30
+ Primitive("logo_GETSET", arrow(arrow(turtle,turtle),turtle,turtle), "")
31
+ ] + [
32
+ Primitive("logo_IFTY", tint, ""),
33
+ Primitive("logo_epsA", tangle, ""),
34
+ Primitive("logo_epsL", tlength, ""),
35
+ Primitive("logo_forLoop", arrow(tint, arrow(tint, turtle, turtle), turtle, turtle), "ERROR: python has no way of expressing this hence you shouldn't eval on this"),
36
+ ] + [Primitive(str(j), tint, j) for j in range(10)]
37
+
38
+ if __name__ == "__main__":
39
+ expr_s = "(lambda (logo_forLoop 3 (lambda (lambda (logo_GET (lambda (logo_FWRT (logo_S2L (logo_I2S 1)) (logo_S2A (logo_I2S 0)) (logo_SET $0 (logo_FWRT (logo_S2L eps) (logo_DIVA (logo_S2A (logo_I2S 2)) (logo_I2S 3)) ($1)))))))) ($0)))"
40
+ x = Program.parse(expr_s)
41
+ print(x)
dreamcoder/domains/logo/main.py ADDED
@@ -0,0 +1,450 @@
1
+ from collections import OrderedDict
2
+ import datetime
3
+ import json
4
+ import os
5
+ import pickle
6
+ import random as random
7
+ import subprocess
8
+ import sys
9
+ import time
10
+
11
+ try:
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ except:
17
+ print("WARNING: Could not import torch. This is only okay when doing pypy compression.",
18
+ file=sys.stderr)
19
+
20
+ from dreamcoder.domains.logo.makeLogoTasks import makeTasks, montageTasks, drawLogo
21
+ from dreamcoder.domains.logo.logoPrimitives import primitives, turtle, tangle, tlength
22
+ from dreamcoder.dreamcoder import ecIterator
23
+ from dreamcoder.grammar import Grammar
24
+ from dreamcoder.program import Program
25
+ try:
26
+ from dreamcoder.recognition import variable, maybe_cuda
27
+ except:
28
+ print("WARNING: Could not import recognition. This is only okay when doing pypy compression.",
29
+ file=sys.stderr)
30
+ from dreamcoder.task import Task
31
+ from dreamcoder.type import arrow
32
+ from dreamcoder.utilities import eprint, testTrainSplit, loadPickle
33
+
34
+
35
+ def animateSolutions(allFrontiers):
36
+ programs = []
37
+ filenames = []
38
+ for n,(t,f) in enumerate(allFrontiers.items()):
39
+ if f.empty: continue
40
+
41
+ programs.append(f.bestPosterior.program)
42
+ filenames.append(f"/tmp/logo_animation_{n}")
43
+
44
+ drawLogo(*programs, pretty=True, smoothPretty=True, resolution=128, animate=True,
45
+ filenames=filenames)
46
+
47
+
48
+
49
+ def dreamFromGrammar(g, directory, N=100):
50
+ if isinstance(g,Grammar):
51
+ programs = [ p
52
+ for _ in range(N)
53
+ for p in [g.sample(arrow(turtle,turtle),
54
+ maximumDepth=20)]
55
+ if p is not None]
56
+ else:
57
+ programs = g
58
+ drawLogo(*programs,
59
+ pretty=False, smoothPretty=False,
60
+ resolution=512,
61
+ filenames=[f"{directory}/{n}.png" for n in range(len(programs)) ],
62
+ timeout=1)
63
+ drawLogo(*programs,
64
+ pretty=True, smoothPretty=False,
65
+ resolution=512,
66
+ filenames=[f"{directory}/{n}_pretty.png" for n in range(len(programs)) ],
67
+ timeout=1)
68
+ drawLogo(*programs,
69
+ pretty=False, smoothPretty=True,
70
+ resolution=512,
71
+ filenames=[f"{directory}/{n}_smooth_pretty.png" for n in range(len(programs)) ],
72
+ timeout=1)
73
+ for n,p in enumerate(programs):
74
+ with open(f"{directory}/{n}.dream","w") as handle:
75
+ handle.write(str(p))
76
+
77
+ try:
78
+ class Flatten(nn.Module):
79
+ def __init__(self):
80
+ super(Flatten, self).__init__()
81
+
82
+ def forward(self, x):
83
+ return x.view(x.size(0), -1)
84
+
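Flatten keeps the batch dimension and collapses the rest, so a feature map flattens into a vector before any linear readout:

    # e.g. Flatten()(torch.zeros(8, 64, 2, 2)).shape == torch.Size([8, 256])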
85
+
86
+ class LogoFeatureCNN(nn.Module):
87
+ special = "LOGO"
88
+
89
+ def __init__(self, tasks, testingTasks=[], cuda=False, H=64):
90
+ super(LogoFeatureCNN, self).__init__()
91
+
92
+ self.sub = prefix_dreams + str(int(time.time()))
93
+
94
+ self.recomputeTasks = False
95
+
96
+ def conv_block(in_channels, out_channels, p=True):
97
+ return nn.Sequential(
98
+ nn.Conv2d(in_channels, out_channels, 3, padding=1),
99
+ # nn.BatchNorm2d(out_channels),
100
+ nn.ReLU(),
101
+ # nn.Conv2d(out_channels, out_channels, 3, padding=1),
102
+ # nn.ReLU(),
103
+ nn.MaxPool2d(2))
104
+
105
+ self.inputImageDimension = 128
106
+ self.resizedDimension = 128
107
+ assert self.inputImageDimension % self.resizedDimension == 0
108
+
109
+ # channels for hidden
110
+ hid_dim = 64
111
+ z_dim = 64
112
+
113
+ self.encoder = nn.Sequential(
114
+ conv_block(1, hid_dim),
115
+ conv_block(hid_dim, hid_dim),
116
+ conv_block(hid_dim, hid_dim),
117
+ conv_block(hid_dim, hid_dim),
118
+ conv_block(hid_dim, hid_dim),
119
+ conv_block(hid_dim, z_dim),
120
+ Flatten()
121
+ )
122
+
123
+ self.outputDimensionality = 256
124
+
125
+
126
+
127
+
128
+ def forward(self, v):
129
+ assert len(v) == self.inputImageDimension*self.inputImageDimension
130
+ floatOnlyTask = list(map(float, v))
131
+ reshaped = [floatOnlyTask[i:i + self.inputImageDimension]
132
+ for i in range(0, len(floatOnlyTask), self.inputImageDimension)]
133
+ v = variable(reshaped).float()
134
+ # insert channel and batch
135
+ v = torch.unsqueeze(v, 0)
136
+ v = torch.unsqueeze(v, 0)
137
+ v = maybe_cuda(v, next(self.parameters()).is_cuda)/256.
138
+ window = int(self.inputImageDimension/self.resizedDimension)
139
+ v = F.avg_pool2d(v, (window,window))
140
+ v = self.encoder(v)
141
+ return v.view(-1)
142
+
143
+ def featuresOfTask(self, t): # Take a task and returns [features]
144
+ return self(t.highresolution)
145
+
146
+ def tasksOfPrograms(self, ps, types):
147
+ images = drawLogo(*ps, resolution=128)
148
+ if len(ps) == 1: images = [images]
149
+ tasks = []
150
+ for i in images:
151
+ if isinstance(i, str): tasks.append(None)
152
+ else:
153
+ t = Task("Helm", arrow(turtle,turtle), [])
154
+ t.highresolution = i
155
+ tasks.append(t)
156
+ return tasks
157
+
158
+ def taskOfProgram(self, p, t):
159
+ return self.tasksOfPrograms([p], None)[0]
160
+ except:
161
+ pass  # presumably tolerates torch/nn being unavailable; the CNN is then simply not defined
162
+
163
+ def list_options(parser):
164
+ parser.add_argument("--proto",
165
+ default=False,
166
+ action="store_true",
167
+ help="Should we use prototypical networks?")
168
+ parser.add_argument("--target", type=str,
169
+ default=[],
170
+ action='append',
171
+ help="Which tasks should this try to solve")
172
+ parser.add_argument("--reduce", type=str,
173
+ default=[],
174
+ action='append',
175
+ help="Which tasks should this try to solve")
176
+ parser.add_argument("--save", type=str,
177
+ default=None,
178
+ help="Filepath output the grammar if this is a child")
179
+ parser.add_argument("--prefix", type=str,
180
+ default="experimentOutputs/",
181
+ help="Filepath output the grammar if this is a child")
182
+ parser.add_argument("--dreamCheckpoint", type=str,
183
+ default=None,
184
+ help="File to load in order to get dreams")
185
+ parser.add_argument("--dreamDirectory", type=str,
186
+ default=None,
187
+ help="Directory in which to dream from --dreamCheckpoint")
188
+ parser.add_argument("--visualize",
189
+ default=None, type=str)
190
+ parser.add_argument("--cost", default=False, action='store_true',
191
+ help="Impose a smooth cost on using ink")
192
+ parser.add_argument("--split",
193
+ default=1., type=float)
194
+ parser.add_argument("--animate",
195
+ default=None, type=str)
196
+
197
+
198
+
199
+ def outputDreams(checkpoint, directory):
200
+ from dreamcoder.utilities import loadPickle
201
+ result = loadPickle(checkpoint)
202
+ eprint(" [+] Loaded checkpoint",checkpoint)
203
+ g = result.grammars[-1]
204
+ if directory is None:
205
+ randomStr = ''.join(random.choice('0123456789') for _ in range(10))
206
+ directory = "/tmp/" + randomStr
207
+ eprint(" Dreaming into",directory)
208
+ os.system("mkdir -p %s"%directory)
209
+ dreamFromGrammar(g, directory)
210
+
211
+ def enumerateDreams(checkpoint, directory):
212
+ from dreamcoder.dreaming import backgroundHelmholtzEnumeration
213
+ from dreamcoder.utilities import loadPickle
214
+ result = loadPickle(checkpoint)
215
+ eprint(" [+] Loaded checkpoint",checkpoint)
216
+ g = result.grammars[-1]
217
+ if directory is None: assert False, "please specify a directory"
218
+ eprint(" Dreaming into",directory)
219
+ os.system("mkdir -p %s"%directory)
220
+ frontiers = backgroundHelmholtzEnumeration(makeTasks(None,None), g, 100,
221
+ evaluationTimeout=0.01,
222
+ special=LogoFeatureCNN.special)()
223
+ print(f"{len(frontiers)} total frontiers.")
224
+ MDL = 0
225
+ def L(f):
226
+ return -list(f.entries)[0].logPrior
227
+ frontiers.sort(key=lambda f: -L(f))
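+ # sorted by descending description length, so pop() below takes the cheapest programs first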
228
+ while len(frontiers) > 0:
229
+ # get frontiers whose MDL is between [MDL,MDL + 1)
230
+ fs = []
231
+ while len(frontiers) > 0 and L(frontiers[-1]) < MDL + 1:
232
+ fs.append(frontiers.pop(len(frontiers) - 1))
233
+ if fs:
234
+ random.shuffle(fs)
235
+ print(f"{len(fs)} programs with MDL between [{MDL}, {MDL + 1})")
236
+
237
+ fs = fs[:500]
238
+ os.system(f"mkdir {directory}/{MDL}")
239
+ dreamFromGrammar([list(f.entries)[0].program for f in fs],
240
+ f"{directory}/{MDL}")
241
+ MDL += 1
242
+
243
+ def visualizePrimitives(primitives, export='/tmp/logo_primitives.png'):
244
+ from itertools import product
245
+ from dreamcoder.program import Index,Abstraction,Application
246
+ from dreamcoder.utilities import montageMatrix,makeNiceArray
247
+ from dreamcoder.type import tint
248
+ import scipy.misc
249
+ from dreamcoder.domains.logo.makeLogoTasks import parseLogo
250
+
251
+ angles = [Program.parse(a)
252
+ for a in ["logo_ZA",
253
+ "logo_epsA",
254
+ "(logo_MULA logo_epsA 2)",
255
+ "(logo_DIVA logo_UA 4)",
256
+ "(logo_DIVA logo_UA 5)",
257
+ "(logo_DIVA logo_UA 7)",
258
+ "(logo_DIVA logo_UA 9)",
259
+ ] ]
260
+ specialAngles = {"#(lambda (lambda (logo_forLoop logo_IFTY (lambda (lambda (logo_FWRT (logo_MULL logo_UL 3) (logo_MULA $2 4) $0))) $1)))":
261
+ [Program.parse("(logo_MULA logo_epsA 4)")]+[Program.parse("(logo_DIVA logo_UA %d)"%n) for n in [7,9] ]}
262
+ numbers = [Program.parse(n)
263
+ for n in ["1","2","5","7","logo_IFTY"] ]
264
+ specialNumbers = {"#(lambda (#(lambda (lambda (lambda (lambda (logo_forLoop $2 (lambda (lambda (logo_FWRT $5 (logo_DIVA logo_UA $3) $0))) $0))))) (logo_MULL logo_UL $0) 4 4))":
265
+ [Program.parse(str(n)) for n in [1,2,3] ]}
266
+ distances = [Program.parse(l)
267
+ for l in ["logo_ZL",
268
+ "logo_epsL",
269
+ "(logo_MULL logo_epsL 2)",
270
+ "(logo_DIVL logo_UL 2)",
271
+ "logo_UL"] ]
272
+ subprograms = [parseLogo(sp)
273
+ for sp in ["(move 1d 0a)",
274
+ "(loop i infinity (move (*l epsilonLength 4) (*a epsilonAngle 2)))",
275
+ "(loop i infinity (move (*l epsilonLength 5) (/a epsilonAngle 2)))",
276
+ "(loop i 4 (move 1d (/a 1a 4)))"]]
277
+
278
+ entireArguments = {"#(lambda (lambda (#(#(lambda (lambda (lambda (logo_forLoop $2 (lambda (lambda (logo_FWRT $2 $3 $0))))))) logo_IFTY) (logo_MULA (#(logo_DIVA logo_UA) $1) $0) (#(logo_MULL logo_UL) 3))))":
279
+ [[Program.parse(str(x)) for x in xs ]
280
+ for xs in [("3", "1", "$0"),
281
+ ("4", "1", "$0"),
282
+ ("5", "1", "$0"),
283
+ ("5", "3", "$0"),
284
+ ("7", "3", "$0")]]}
285
+ specialDistances = {"#(lambda (lambda (logo_forLoop 7 (lambda (lambda (#(lambda (lambda (lambda (#(lambda (lambda (lambda (logo_forLoop $2 (lambda (lambda (logo_FWRT $2 $3 $0))))))) 7 $1 $2 $0)))) $3 logo_epsA $0))) $0)))":
286
+ [Program.parse("(logo_MULL logo_epsL %d)"%n) for n in range(5)]}
287
+
288
+ matrix = []
289
+ for p in primitives:
290
+ if not p.isInvented: continue
291
+ t = p.tp
292
+ eprint(p,":",p.tp)
293
+ if t.returns() != turtle:
294
+ eprint("\t(does not return a turtle)")
295
+ continue
296
+
297
+ def argumentChoices(t):
298
+ if t == turtle:
299
+ return [Index(0)]
300
+ elif t == arrow(turtle,turtle):
301
+ return subprograms
302
+ elif t == tint:
303
+ return specialNumbers.get(str(p),numbers)
304
+ elif t == tangle:
305
+ return specialAngles.get(str(p),angles)
306
+ elif t == tlength:
307
+ return specialDistances.get(str(p),distances)
308
+ else: return []
309
+
310
+ ts = []
311
+ for arguments in entireArguments.get(str(p),product(*[argumentChoices(t) for t in t.functionArguments() ])):
312
+ eprint(arguments)
313
+ pp = p
314
+ for a in arguments: pp = Application(pp,a)
315
+ pp = Abstraction(pp)
316
+ i = np.reshape(np.array(drawLogo(pp, resolution=128)), (128,128))
317
+ if i is not None:
318
+ ts.append(i)
319
+
320
+
321
+ if ts == []: continue
322
+
323
+ matrix.append(ts)
324
+ if len(ts) < 6: ts = [ts]
325
+ else: ts = makeNiceArray(ts)
326
+ r = montageMatrix(ts)
327
+ fn = "/tmp/logo_primitive_%d.png"%len(matrix)
328
+ eprint("\tExported to",fn)
329
+ scipy.misc.imsave(fn, r)
330
+
331
+ matrix = montageMatrix(matrix)
332
+ scipy.misc.imsave(export, matrix)
333
+
334
+
335
+ def main(args):
336
+ """
337
+ Takes the return value of the `commandlineArguments()` function as input and
338
+ trains/tests the model on LOGO tasks.
339
+ """
340
+
341
+ # The below legacy global statement is required since prefix_dreams is used by LogoFeatureCNN.
342
+ # TODO(lcary): use argument passing instead of global variables.
343
+ global prefix_dreams
344
+
345
+ # The below global statement is required since primitives is modified within main().
346
+ # TODO(lcary): use a function call to retrieve and declare primitives instead.
347
+ global primitives
348
+
349
+ visualizeCheckpoint = args.pop("visualize")
350
+ if visualizeCheckpoint is not None:
351
+ with open(visualizeCheckpoint,'rb') as handle:
352
+ primitives = pickle.load(handle).grammars[-1].primitives
353
+ visualizePrimitives(primitives)
354
+ sys.exit(0)
355
+
356
+ dreamCheckpoint = args.pop("dreamCheckpoint")
357
+ dreamDirectory = args.pop("dreamDirectory")
358
+
359
+ proto = args.pop("proto")
360
+
361
+ if dreamCheckpoint is not None:
362
+ #outputDreams(dreamCheckpoint, dreamDirectory)
363
+ enumerateDreams(dreamCheckpoint, dreamDirectory)
364
+ sys.exit(0)
365
+
366
+ animateCheckpoint = args.pop("animate")
367
+ if animateCheckpoint is not None:
368
+ animateSolutions(loadPickle(animateCheckpoint).allFrontiers)
369
+ sys.exit(0)
370
+
371
+ target = args.pop("target")
372
+ red = args.pop("reduce")
373
+ save = args.pop("save")
374
+ prefix = args.pop("prefix")
375
+ prefix_dreams = prefix + "/dreams/" + ('_'.join(target)) + "/"
376
+ prefix_pickles = prefix + "/logo." + ('.'.join(target))
377
+ if not os.path.exists(prefix_dreams):
378
+ os.makedirs(prefix_dreams)
379
+ tasks = makeTasks(target, proto)
380
+ eprint("Generated", len(tasks), "tasks")
381
+
382
+ costMatters = args.pop("cost")
383
+ for t in tasks:
384
+ t.specialTask[1]["costMatters"] = costMatters
385
+ # disgusting hack - include whether cost matters in the dummy input
386
+ if costMatters: t.examples = [(([1]), t.examples[0][1])]
387
+
388
+ os.chdir("prototypical-networks")
389
+ subprocess.Popen(["python","./protonet_server.py"])
390
+ time.sleep(3)
391
+ os.chdir("..")
392
+
393
+
394
+ test, train = testTrainSplit(tasks, args.pop("split"))
395
+ eprint("Split tasks into %d/%d test/train" % (len(test), len(train)))
396
+ try:
397
+ if test: montageTasks(test,"test_")
398
+ montageTasks(train,"train_")
399
+ except:
400
+ eprint("WARNING: couldn't generate montage. Do you have an old version of scipy?")
401
+
402
+ if red:  # "red is not []" compared identity and was always True; test truthiness instead
403
+ for reducing in red:
404
+ try:
405
+ with open(reducing, 'r') as f:
406
+ prods = json.load(f)
407
+ for e in prods:
408
+ e = Program.parse(e)
409
+ if e.isInvented:
410
+ primitives.append(e)
411
+ except EOFError:
412
+ eprint("Couldn't grab frontier from " + reducing)
413
+ except IOError:
414
+ eprint("Couldn't grab frontier from " + reducing)
415
+ except json.decoder.JSONDecodeError:
416
+ eprint("Couldn't grab frontier from " + reducing)
417
+
418
+ primitives = list(OrderedDict((x, True) for x in primitives).keys())
419
+ baseGrammar = Grammar.uniform(primitives, continuationType=turtle)
420
+
421
+ eprint(baseGrammar)
422
+
423
+ timestamp = datetime.datetime.now().isoformat()
424
+ outputDirectory = "experimentOutputs/logo/%s"%timestamp
425
+ os.system("mkdir -p %s"%outputDirectory)
426
+
427
+
428
+ generator = ecIterator(baseGrammar, train,
429
+ testingTasks=test,
430
+ outputPrefix="%s/logo"%outputDirectory,
431
+ evaluationTimeout=0.01,
432
+ **args)
433
+
434
+ r = None
435
+ for result in generator:
436
+ iteration = len(result.learningCurve)
437
+ dreamDirectory = "%s/dreams_%d"%(outputDirectory, iteration)
438
+ os.system("mkdir -p %s"%dreamDirectory)
439
+ eprint("Dreaming into directory",dreamDirectory)
440
+ dreamFromGrammar(result.grammars[-1],
441
+ dreamDirectory)
442
+ r = result
443
+
444
+ needsExport = [str(z)
445
+ for _, _, z
446
+ in r.grammars[-1].productions
447
+ if z.isInvented]
448
+ if save is not None:
449
+ with open(save, 'w') as f:
450
+ json.dump(needsExport, f)
dreamcoder/domains/logo/makeLogoTasks.py ADDED
@@ -0,0 +1,777 @@
1
+ # coding: utf8
2
+
3
+ import os
4
+ import random
5
+ import sys
6
+
7
+ from dreamcoder.domains.logo.logoPrimitives import primitives, turtle
8
+ from dreamcoder.task import Task
9
+ from dreamcoder.program import Abstraction, Application, Index, Program
10
+ from dreamcoder.type import arrow
11
+ from dreamcoder.utilities import eprint, jsonBinaryInvoke, random_seed, montage
12
+ from dreamcoder.grammar import Grammar
13
+
14
+
15
+ def drawLogo(*programs,
16
+ timeout=None,
17
+ resolution=None,
18
+ pretty=False, smoothPretty=False,
19
+ filenames=[],
20
+ animate=False,
21
+ cost=False):
22
+ message = {}
23
+ if pretty: message["pretty"] = pretty
24
+ if smoothPretty: message["smoothPretty"] = smoothPretty
25
+ if timeout: message["timeout"] = timeout
26
+ assert resolution is not None, "resolution not provided in drawLogo"
27
+ if isinstance(resolution, list):
28
+ assert len(resolution) == len(programs), "must provide a resolution for each program"
29
+ elif isinstance(resolution, int):
30
+ resolution = [resolution]*len(programs)
31
+ else: assert False
32
+ jobs = []
33
+ for p, size in zip(programs, resolution):
34
+ entry = {"program": str(p),
35
+ "size": size}
36
+ if animate: entry["animate"] = True
37
+ if len(filenames) > 0:
38
+ entry["export"] = filenames[0]
39
+ filenames = filenames[1:]
40
+ jobs.append(entry)
41
+ message["jobs"] = jobs
42
+ response = jsonBinaryInvoke("./logoDrawString", message)
43
+ if cost:
44
+ # include the cost and return tuples of (pixels, cost)
45
+ response = [programResponse if isinstance(programResponse,str) else (programResponse["pixels"], programResponse["cost"])
46
+ for programResponse in response ]
47
+ else:
48
+ response = [programResponse if isinstance(programResponse,str) else programResponse["pixels"]
49
+ for programResponse in response ]
50
+ if len(programs) == 1:
51
+ return response[0]
52
+ return response
53
+
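+ # Usage sketch for drawLogo (assumes the ./logoDrawString binary invoked above has been built):
+ #   bitmap = drawLogo(program, resolution=28)                     # one program -> one result
+ #   small, big = drawLogo(p, p, resolution=[28, 128], cost=True)  # per-program sizes; (pixels, cost) pairs
+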
54
+ def makeTasks(subfolders, proto):
55
+ return manualLogoTasks()
56
+
57
+ def parseLogo(s):
58
+
59
+ _ua = Program.parse("logo_UA")
60
+ _ul = Program.parse("logo_UL")
61
+
62
+ _za = Program.parse("logo_ZA")
63
+ _zl = Program.parse("logo_ZL")
64
+
65
+ _da = Program.parse("logo_DIVA")
66
+ _ma = Program.parse("logo_MULA")
67
+ _dl = Program.parse("logo_DIVL")
68
+ _ml = Program.parse("logo_MULL")
69
+
70
+ _aa = Program.parse("logo_ADDA")
71
+ _sa = Program.parse("logo_SUBA")
72
+ _al = None#Program.parse("logo_ADDL")
73
+ _sl = None#Program.parse("logo_SUBL")
74
+
75
+ _pu = None#Program.parse("logo_PU")
76
+ _pd = None#Program.parse("logo_PD")
77
+ _p = Program.parse("logo_PT")
78
+ _move = Program.parse("logo_FWRT")
79
+ _embed = Program.parse("logo_GETSET")
80
+
81
+ _addition = Program.parse("+")
82
+ _infinity = Program.parse("logo_IFTY")
83
+ _ea = Program.parse("logo_epsA")
84
+ _el = Program.parse("logo_epsL")
85
+ _loop = Program.parse("logo_forLoop")
86
+
87
+ from sexpdata import loads, Symbol
88
+ s = loads(s)
89
+ def command(k, environment, continuation):
90
+ assert isinstance(k,list)
91
+ if k[0] == Symbol("move"):
92
+ return Application(Application(Application(_move,
93
+ expression(k[1],environment)),
94
+ expression(k[2],environment)),
95
+ continuation)
96
+ if k[0] == Symbol("for") or k[0] == Symbol("loop"):
97
+ v = k[1]
98
+ b = expression(k[2], environment)
99
+ newEnvironment = [None, v] + environment
100
+ body = block(k[3:], newEnvironment, Index(0))
101
+ return Application(Application(Application(_loop,b),
102
+ Abstraction(Abstraction(body))),
103
+ continuation)
104
+ if k[0] == Symbol("embed"):
105
+ body = block(k[1:], [None] + environment, Index(0))
106
+ return Application(Application(_embed,Abstraction(body)),continuation)
107
+ if k[0] == Symbol("p"):
108
+ body = block(k[1:], [None] + environment, Index(0))
109
+ return Application(Application(_p,Abstraction(body)),continuation)
110
+
111
+ assert False
112
+ def expression(e, environment):
113
+ for n, v in enumerate(environment):
114
+ if e == v: return Index(n)
115
+
116
+ if isinstance(e,int): return Program.parse(str(e))
117
+
118
+ mapping = {"1a": _ua,
119
+ "1d": _ul, "1l": _ul,
120
+ "0a": _za,
121
+ "0d": _zl, "0l": _zl,
122
+ "/a": _da,
123
+ "/l": _dl, "/d": _dl,
124
+ "*a": _ma,
125
+ "*l": _ml, "*d": _ml,
126
+ "+a": _aa,
127
+ "+d": _al, "+l": _al,
128
+ "-a": _sa,
129
+ "-d": _sl, "-l": _sl,
130
+ "+": _addition,
131
+ "infinity": _infinity,
132
+ "epsilonAngle": _ea,
133
+ "epsilonDistance": _el,
134
+ "epsilonLength": _el}
135
+ if e == float('inf'): return _infinity
136
+ for name, value in mapping.items():
137
+ if e == Symbol(name): return value
138
+
139
+ assert isinstance(e,list), "not a list %s"%e
140
+ for name, value in mapping.items():
141
+ if e[0] == Symbol(name):
142
+ f = value
143
+ for argument in e[1:]:
144
+ f = Application(f, expression(argument, environment))
145
+ return f
146
+ assert False
147
+
148
+ def block(b, environment, continuation):
149
+ if len(b) == 0: return continuation
150
+ return command(b[0], environment, block(b[1:], environment, continuation))
151
+
152
+ try: return Abstraction(command(s, [], Index(0)))
153
+ except: return Abstraction(block(s, [], Index(0)))
154
+
155
+
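+ # parseLogo compiles an s-expression surface syntax down to the lambda-calculus primitives
+ # above; e.g. (a sketch) parseLogo("(loop i 4 (move 1d (/a 1a 4)))") yields a program of
+ # type turtle -> turtle that draws a unit square.
+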
156
+ def manualLogoTask(name, expression, proto=False, needToTrain=False,
157
+ supervise=False, lambdaCalculus=False):
158
+ p = Program.parse(expression) if lambdaCalculus else parseLogo(expression)
159
+ from dreamcoder.domains.logo.logoPrimitives import primitives
160
+ from dreamcoder.grammar import Grammar
161
+ g = Grammar.uniform(primitives, continuationType=turtle)
162
+ gp = Grammar.uniform(primitives)
163
+ try:
164
+ l = g.logLikelihood(arrow(turtle,turtle),p)
165
+ lp = gp.logLikelihood(arrow(turtle,turtle),p)
166
+ assert l >= lp
167
+ eprint(name,-l,"nats")
168
+
169
+ except: eprint("WARNING: could not calculate likelihood of manual logo",p)
170
+
171
+ attempts = 0
172
+ while True:
173
+ [output, highresolution] = drawLogo(p, p, resolution=[28,128], cost=True)
174
+ if output == "timeout" or highresolution == "timeout":
175
+ attempts += 1
176
+ else:
177
+ break
178
+ if attempts > 0:
179
+ eprint(f"WARNING: Took {attempts} attempts to render task {name} within timeout")
180
+
181
+ cost = output[1]
182
+ output = output[0]
183
+ assert highresolution[1] == cost
184
+ highresolution = highresolution[0]
185
+
186
+ shape = list(map(int, output))
187
+ highresolution = list(map(float, highresolution))
188
+ t = Task(name, arrow(turtle,turtle),
189
+ [(([0]), shape)])
190
+ t.mustTrain = needToTrain
191
+ t.proto = proto
192
+ t.specialTask = ("LOGO", {"proto": proto})
193
+ t.specialTask[1]["cost"] = cost*1.05
194
+
195
+ t.highresolution = highresolution
196
+
197
+ if supervise:
198
+ t.supervisedSolution = p
199
+
200
+ return t
201
+
202
+ def dSLDemo():
203
+ n = 0
204
+ demos = []
205
+ def T(source):
206
+ demos.append(manualLogoTask(str(len(demos)), source,
207
+ lambdaCalculus="lambda" in source))
208
+ # this looks like polygons - verify and include
209
+ T("(#(lambda (lambda (#(lambda (lambda (#(lambda (lambda (lambda (logo_forLoop $0 (lambda (lambda (logo_FWRT $4 $3 $0))))))) $1 $0 logo_IFTY))) $1 (logo_DIVA logo_UA $0)))) (logo_MULL logo_UL 4) 3)")
210
+ T("(#(lambda (lambda (#(lambda (lambda (#(lambda (lambda (lambda (logo_forLoop $0 (lambda (lambda (logo_FWRT $4 $3 $0))))))) $1 $0 logo_IFTY))) $1 (logo_DIVA logo_UA $0)))) (logo_MULL logo_UL 6) 4)")
211
+ T("(#(lambda (lambda (#(lambda (lambda (#(lambda (lambda (lambda (logo_forLoop $0 (lambda (lambda (logo_FWRT $4 $3 $0))))))) $1 $0 logo_IFTY))) $1 (logo_DIVA logo_UA $0)))) (logo_MULL logo_UL 5) 5)")
212
+ T("(#(lambda (lambda (#(lambda (lambda (#(lambda (lambda (lambda (logo_forLoop $0 (lambda (lambda (logo_FWRT $4 $3 $0))))))) $1 $0 logo_IFTY))) $1 (logo_DIVA logo_UA $0)))) (logo_MULL logo_UL 3) 6)")
213
+ T("(#(lambda (lambda (#(lambda (lambda (#(lambda (lambda (lambda (logo_forLoop $0 (lambda (lambda (logo_FWRT $4 $3 $0))))))) $1 $0 logo_IFTY))) $1 (logo_DIVA logo_UA $0)))) (logo_MULL logo_UL 2) 7)")
214
+
215
+ # Spirals!
216
+ for spiralSize in [1,2,3,4,5]:
217
+ T(f"((lambda (logo_forLoop logo_IFTY (lambda (lambda (logo_FWRT (logo_MULL logo_epsL $1) (logo_MULA logo_epsA $2) $0))))) {spiralSize})")
218
+ for spiralSize in [5,6,7,8,9]:
219
+ #T(f"(lambda (#(lambda (logo_forLoop $0 (lambda (lambda (#(lambda (logo_FWRT (logo_MULL logo_UL $0) (logo_DIVA logo_UA 4))) $1 $0))))) {spiralSize} $0))")
220
+ T("(loop i " + str(spiralSize) + " (move (*d 1l i) (/a 1a 4)))")# (#(lambda (logo_forLoop $0 (lambda (lambda (#(lambda (logo_FWRT (logo_MULL logo_UL $0) (logo_DIVA logo_UA 4))) $1 $0))))) {spiralSize} $0))")
221
+
222
+ # CIRCLES
223
+ #(lambda (#(lambda (logo_forLoop 6 (lambda (lambda (#(lambda (lambda (logo_forLoop logo_IFTY (lambda (lambda (logo_FWRT $2 $3 $0)))))) logo_epsA (logo_MULL logo_epsL $2) $0))))) 6 $0))
224
+ for circleSize in [1,3,5,7,9]:
225
+ T(f"(lambda (#(lambda (logo_forLoop 6 (lambda (lambda (#(lambda (lambda (logo_forLoop logo_IFTY (lambda (lambda (logo_FWRT $2 $3 $0)))))) logo_epsA (logo_MULL logo_epsL $2) $0))))) {circleSize} $0))")
226
+
227
+ T("(loop i 3 (move (*d 1l 3) (/a 1a 4)))")
228
+ T("(loop i 5 (move (*d 1l 5) (/a 1a 5)))")
229
+ T("(loop i infinity (move (*d epsilonDistance 5) (/a epsilonAngle 3)))")
230
+ T("(loop i infinity (move (*d epsilonDistance 9) (/a epsilonAngle 2)))")
231
+ T("(loop i infinity (move (*d epsilonLength i) (*a epsilonAngle 3)))")
232
+ T("(loop i 9 (move (*d 1l i) (/a 1a 4)))")
233
+ T("(move 1d 0a)")
234
+ T("(loop i infinity (move (*d epsilonLength 6) epsilonAngle))")
235
+ T("(loop i infinity (move (*d epsilonLength 8) epsilonAngle))")
236
+ T("(loop k 2 (loop i infinity (move (*d epsilonLength 4) epsilonAngle)))")
237
+ T("(loop k 2 (loop i infinity (move (*d epsilonLength 8) epsilonAngle)))")
238
+ T("(loop s 4 (move (*d 1d 3) (/a 1a 4)))")
239
+ T("(loop s 4 (move (*d 1d 6) (/a 1a 4)))")
240
+ T("""
241
+ (loop j 5
242
+ (move 0d (/a 1a 5))
243
+ (embed (loop i infinity
244
+ (move (*d epsilonLength 6) epsilonAngle))
245
+ (loop i infinity
246
+ (move (*d epsilonLength 6) epsilonAngle))))""")
247
+ T("""
248
+ (loop j 5
249
+ (embed (loop s 4 (move (*d 1d 3) (/a 1a 4))))
250
+ (move 0d (/a 1a 5)))""")
251
+ return demos
252
+
253
+ def rotationalSymmetryDemo():
254
+ demos = []
255
+ def T(source):
256
+ demos.append(manualLogoTask(str(len(demos)), source))
257
+
258
+ body = {"dashed": "(p (move 1d 0a)) (move 1d 0a) (p (move 1d 0a)) (move 1d 0a)",
259
+ "lonely circle": "(p (move (*d 1d 2) 0a)) (loop k 2 (loop i infinity (move (*d epsilonLength 2) epsilonAngle)))",
260
+ "square dashed": "(p (move 1d 0a)) (loop s 4 (move 1d (/a 1a 4)))",
261
+ "square": "(loop s 4 (move (*d 1d 2) (/a 1a 4)))",
262
+ "semicircle": "(loop i infinity (move (*d epsilonLength 4) epsilonAngle))"}
263
+ for name in body:
264
+ for n in [3,4,5,6,7]:
265
+ T("""
266
+ (loop j %d
267
+ (embed %s)
268
+ (move 0d (/a 1a %d)))"""%(n,body[name],n))
269
+ return demos
270
+
271
+
272
+ def manualLogoTasks():
273
+ tasks = []
274
+ def T(name, source, needToTrain=False, supervise=False):
275
+ tasks.append(manualLogoTask(name, source, supervise=supervise,
276
+ needToTrain=needToTrain))
277
+ if False:
278
+ for d,a,s in [('1l','0a','(loop i infinity (move epsilonLength epsilonAngle))'),
279
+ ('epsilonLength','0a','(loop i infinity (move epsilonLength epsilonAngle))'),
280
+ ('(*d 1l 3)','0a','(move 1l 0a)'),
281
+ ('epsilonLength','0a','(move (*d 1l 2) 0a)'),
282
+ ('(*d epsilonLength 9)','0a','(move epsilonLength 0a)'),
283
+ ('(/d 1l 2)','0a','(move 1l 0a)')]:
284
+ # 'epsilonLength']:
285
+ # for a in ['epsilonAngle','0a']:
286
+ # for s in ['(move 1l 0a)',
287
+ # '(move epsilonLength 0a)',
288
+ # '(loop i infinity (move epsilonLength epsilonAngle))']:
289
+ # if d == 'epsilonLength' and s == '(move epsilonLength 0a)': continue
290
+ T("pu: %s/%s/%s"%(d,a,s),
291
+ """
292
+ (pu (move %s %s) pd %s)
293
+ """%(d,a,s))
294
+ return tasks
295
+
296
+ def slant(n):
297
+ return f"(move 0d (/a 1a {n}))"
298
+
299
+ for n,l,s in [(3,"1l",8),
300
+ (4,"(*d 1d 3)",None),
301
+ (5,"1l",None),
302
+ (6,"(*d 1d 2)",5),
303
+ (7,"1l",None),
304
+ (8,"(/d 1d 2)",None)]:
305
+ T(f"{n}-gon {l}{'' if s is None else ' slanted '+str(s)}",
306
+ f"""
307
+ ({'' if s is None else slant(s)}
308
+ (loop i {n}
309
+ (move {l} (/a 1a {n}))))
310
+ """,
311
+ needToTrain=True)
312
+ for n,l,s in [(3,"(*d 1l 2)",None),
313
+ (4,"(*d 1d 4)",None),
314
+ (5,"(*d 1d 2)",None),
315
+ (6,"1l",None),
316
+ (7,"(*d 1d 3)",None),
317
+ (8,"1l",3)]:
318
+ T(f"{n}-gon {l}{'' if s is None else ' slanted '+str(s)}",
319
+ f"""
320
+ ({'' if s is None else slant(s)}
321
+ (loop i {n}
322
+ (move {l} (/a 1a {n}))))
323
+ """,
324
+ needToTrain=False)
325
+
326
+
327
+
328
+ T("upwards", "((move 0d (/a 1a 4)) (move 1d 0a))",
329
+ needToTrain=True)
330
+ T("right angle", "((move (*d 1d 2) (/a 1a 4)) (move 1d 0a))",
331
+ needToTrain=True)
332
+ T("right angle epsilon", "((move epsilonLength (/a 1a 4)) (move epsilonLength 0a))",
333
+ needToTrain=True)
334
+
335
+ T("line segment", "(move 1d 0a)",
336
+ needToTrain=True)
337
+
338
+ T("square slanted by 2pi/3",
339
+ """((move 0d (/a 1a 3))
340
+ (loop k 4 (move 1d (/a 1a 4))))""",
341
+ needToTrain=True)
342
+ T("semicircle slanted by 2pi/5",
343
+ """((move 0d (/a 1a 5))
344
+ (loop i infinity
345
+ (move (*d epsilonLength 4) epsilonAngle)))""",
346
+ needToTrain=True)
347
+ T("Greek spiral slanted by 2pi/6",
348
+ """((move 0d (/a 1a 6))
349
+ (loop i 7 (move (*l 1l i) (/a 1a 4))))""",
350
+ needToTrain=True)
351
+ T("Hook slanted by 2pi/7",
352
+ """((move 0d (/a 1a 7))
353
+ (move 1d 0a)
354
+ (loop i infinity
355
+ (move (*d epsilonLength 4) epsilonAngle)))""",
356
+ needToTrain=True)
357
+ T("""slanted line""",
358
+ """((move 0d (/a 1a 8))
359
+ (move (*d 1l 3) 0a))""",
360
+ needToTrain=True)
361
+
362
+
363
+ for i in [6,7,8,9]:
364
+ T("Greek spiral %d"%i,
365
+ """
366
+ (loop i %d
367
+ (move (*l 1l i) (/a 1a 4)))
368
+ """%i,
369
+ needToTrain=i in [7,8])
370
+ for i in [2,3,4,5]:
371
+ T("smooth spiral %d"%i,
372
+ """
373
+ (loop i infinity
374
+ (move (*d epsilonLength i) (*a epsilonAngle %d)))
375
+ """%i,
376
+ needToTrain=i in [3,5])
377
+
378
+ T("smooth spiral 4 slanted by 2pi/2",
379
+ """
380
+ ((move 0d (/a 1a 2))
381
+ (loop i infinity
382
+ (move (*d epsilonLength i) (*a epsilonAngle 4))))
383
+ """,
384
+ needToTrain=True)
385
+
386
+ for i in [3,5,7,9]:
387
+ T("star %d"%i,
388
+ """
389
+ (loop i %d (move (*d 1d 4) (-a (/a 1a 2) (/a (/a 1a 2) %s))))
390
+ """%(i,i),
391
+ needToTrain=i in [5,9])
392
+
393
+ T("leaf iteration 1.1",
394
+ """
395
+ (loop i infinity (move epsilonDistance (/a epsilonAngle 2)))
396
+ """,
397
+ needToTrain=True)
398
+ T("leaf iteration 1.2",
399
+ """
400
+ ((move 0d (/a 1a 2))
401
+ (loop i infinity (move epsilonDistance (/a epsilonAngle 2))))
402
+ """,
403
+ needToTrain=True)
404
+ T("leaf iteration 2.1",
405
+ """
406
+ (loop n 2
407
+ (loop i infinity (move epsilonDistance (/a epsilonAngle 2)))
408
+ (move 0d (/a 1a 4)))
409
+ """,
410
+ needToTrain=True)
411
+ T("leaf iteration 2.2",
412
+ """
413
+ ((move 0d (/a 1a 2))
414
+ (loop n 2
415
+ (loop i infinity (move epsilonDistance (/a epsilonAngle 2)))
416
+ (move 0d (/a 1a 4))))
417
+ """,
418
+ needToTrain=True)
419
+ for n in range(3,8):
420
+ T("flower %d"%n,
421
+ """
422
+ (loop j %d
423
+ (loop n 2
424
+ (loop i infinity (move epsilonDistance (/a epsilonAngle 2)))
425
+ (move 0d (/a 1a 4)))
426
+ (move 0d (/a 1a %d)))
427
+ """%(n,n),
428
+ needToTrain=n in range(3,5))
429
+
430
+ for n in [5,6]:
431
+ T("staircase %d"%n,
432
+ """
433
+ (loop i %d
434
+ (move 1d (/a 1a 4))
435
+ (move 1d (/a 1a 4))
436
+ (move 0d (/a 1a 2)))
437
+ """%n,
438
+ needToTrain=n in [5])
439
+
440
+ for n in range(1,6):
441
+ T("blocks zigzag %d"%n,
442
+ """
443
+ (loop i %d
444
+ (move 1d (/a 1a 4)) (move 1d (/a 1a 4))
445
+ (move 1d (+a (/a 1a 2) (/a 1a 4))) (move 1d (+a (/a 1a 2) (/a 1a 4))))
446
+ """%n,
447
+ needToTrain=n in [1,2,3])
448
+ for n in [3,4]:#range(1,5):
449
+ T("diagonal zigzag %d"%n,
450
+ """
451
+ ((move 0d (/a 1a 8))
452
+ (loop i %d
453
+ (move 1d (/a 1a 4))
454
+ (move 1d (+a (/a 1a 2) (/a 1a 4)))))
455
+ """%n,
456
+ needToTrain=n == 4)
457
+
458
+
459
+
460
+ for n in [1,2,3,4,5,6]:
461
+ T("right semicircle of size %d"%n,
462
+ """
463
+ (loop i infinity
464
+ (move (*d epsilonLength %d) (-a 0a epsilonAngle)))
465
+ """%n,
466
+ needToTrain=n%2 == 0)
467
+ T("left semicircle of size %d"%n,
468
+ f"""
469
+ ({'' if n != 1 else slant(8)}
470
+ (loop i infinity
471
+ (move (*d epsilonLength {n}) epsilonAngle)))
472
+ """,
473
+ needToTrain=n%2 == 1)
474
+ T("circle of size %d"%n,
475
+ """
476
+ ((loop i infinity
477
+ (move (*d epsilonLength %d) epsilonAngle))
478
+ (loop i infinity
479
+ (move (*d epsilonLength %d) epsilonAngle)))
480
+ """%(n,n),
481
+ needToTrain=n in [1,4,3,5,6])
482
+
483
+ for n in [5,6]:
484
+ T("%d enclosed circles"%n,
485
+ """
486
+ (loop j %d
487
+ (loop i infinity
488
+ (move (*d epsilonLength j) epsilonAngle))
489
+ (loop i infinity
490
+ (move (*d epsilonLength j) epsilonAngle)))"""%n,
491
+ needToTrain=n == 5)
492
+
493
+ for n,l in [(4,2),
494
+ (5,3),
495
+ (6,4),
496
+ (3,1)]:
497
+ T("%d-circle flower l=%d"%(n,l),
498
+ """
499
+ (loop j %d
500
+ (move 0d (/a 1a %d))
501
+ (embed (loop i infinity
502
+ (move (*d epsilonLength %d) epsilonAngle))
503
+ (loop i infinity
504
+ (move (*d epsilonLength %d) epsilonAngle))))"""%(n,n,l,l),
505
+ needToTrain=(n,l) in [(6,4),(3,1)])
506
+
507
+ for n,l in [(3,1),(2,2),(1,3),
508
+ (2,1),(1,2),(1,1)]:
509
+ T("%d-semicircle sequence L=%d"%(n,l),
510
+ """
511
+ (loop j %d
512
+ (loop i infinity
513
+ (move (*d epsilonLength %d) epsilonAngle))
514
+ (loop i infinity
515
+ (move (*d epsilonLength %d) (-a 0a epsilonAngle))))
516
+ """%(n,l,l),
517
+ needToTrain=(n,l) in [(3,1),(2,2),(1,3)])
518
+
519
+ for n,l in [(2,"1d"),
520
+ (3,"1d")]:
521
+ T("row of %d circles"%n,
522
+ """
523
+ (loop j %d
524
+ (embed (loop k 2 (loop i infinity (move epsilonLength epsilonAngle))))
525
+ (p (move %s 0a)))"""%(n,l),
526
+ needToTrain=n == 2)
527
+ for n,l in [(2,"1d"),
528
+ (3,"1d")]:
529
+ T("row of %d lines"%n,
530
+ """
531
+ (loop j %d
532
+ (move 1d 0a)
533
+ (p (move %s 0a)))"""%(n,l),
534
+ needToTrain=n == 2)
535
+ T("line next to semicircle",
536
+ """
537
+ ((move 1d 0a) (p (move 1d 0a)) (loop i infinity (move epsilonLength epsilonAngle)))
538
+ """,
539
+ needToTrain=True)
540
+ for n,l in [(3,"(/d 1d 2)"),
541
+ (4,"(/d 1d 3)")]:
542
+ T("%d dashed lines of size %s"%(n,l),
543
+ """(loop i %d (p (move 1d 0a)) (move %s 0a))"""%(n,l),
544
+ needToTrain=n == 3)
545
+ T("broken circle",
546
+ """
547
+ ((loop i infinity (move epsilonLength epsilonAngle)) (p (move 1d 0a)) (loop i infinity (move epsilonLength epsilonAngle)))
548
+ """,
549
+ needToTrain=True)
550
+ T("circle next to semicircle",
551
+ """
552
+ ((loop i infinity (move epsilonLength epsilonAngle))
553
+ (loop i infinity (move epsilonLength epsilonAngle))
554
+ (p (move 1d 0a))
555
+ (loop i infinity (move epsilonLength epsilonAngle)))
556
+ """,
557
+ needToTrain=True)
558
+ T("semicircle next to square",
559
+ """
560
+ ((loop i infinity (move epsilonLength epsilonAngle))
561
+ (p (move 1d 0a))
562
+ (loop i infinity (move 1d (/a 1a 4))))
563
+ """,
564
+ needToTrain=False)
565
+ T("circle next to square",
566
+ """
567
+ ((loop i infinity (move epsilonLength epsilonAngle))
568
+ (loop i infinity (move epsilonLength epsilonAngle))
569
+ (p (move 1d 0a))
570
+ (loop i infinity (move 1d (/a 1a 4))))
571
+ """,
572
+ needToTrain=False)
573
+ T("circle next to line",
574
+ """
575
+ ((loop i infinity (move epsilonLength epsilonAngle))
576
+ (loop i infinity (move epsilonLength epsilonAngle))
577
+ (p (move 1d 0a))
578
+ (move 1d 0a))
579
+ """,
580
+ needToTrain=True)
581
+ T("line next to circle",
582
+ """
583
+ ((move 1d 0a)
584
+ (p (move 1d 0a))
585
+ (loop i infinity (move epsilonLength epsilonAngle))
586
+ (loop i infinity (move epsilonLength epsilonAngle))
587
+ (move 1d 0a))
588
+ """,
589
+ needToTrain=True)
590
+ for n,l in [(4,"1d"),
591
+ (5,"1d")]:
592
+ T("row of %d dashes"%n,
593
+ """
594
+ (loop j %d
595
+ (embed (move 0d (/a 1a 4)) (move 1d 0a))
596
+ (p (move %s 0a)))"""%(n,l),
597
+ needToTrain=n == 4)
598
+ for n,l in [(5,"1d"),(6,"1d")]:
599
+ T("row of %d semicircles"%n,
600
+ """
601
+ (loop j %d
602
+ (embed (loop i infinity (move epsilonLength epsilonAngle)))
603
+ (p (move %s 0a)))"""%(n,l),
604
+ needToTrain=n == 5)
605
+
606
+ with random_seed(42): # carefully selected for maximum entropy
607
+ for n in [3,4,5,6,7]:
608
+ body = {"empty": "(move 1d 0a)",
609
+ "spiral": "(loop i infinity (move (*d epsilonLength i) (*a epsilonAngle 2)))",
610
+ "dashed": "(p (move 1d 0a)) (move 1d 0a)",
611
+ "circle": "(move 1d 0a) (loop k 2 (loop i infinity (move epsilonLength epsilonAngle)))",
612
+ "lonely circle": "(p (move 1d 0a)) (loop k 2 (loop i infinity (move epsilonLength epsilonAngle)))",
613
+ "square dashed": "(p (move 1d 0a)) (loop s 4 (move 1d (/a 1a 4)))",
614
+ "square": "(move 1d 0a) (loop s 4 (move 1d (/a 1a 4)))",
615
+ "close large semicircle": "(loop i infinity (move (*d epsilonLength 2) epsilonAngle))",
616
+ "close semicircle": "(loop i infinity (move epsilonLength epsilonAngle))",
617
+ "semicircle": "(move 1d 0a) (loop i infinity (move epsilonLength epsilonAngle))",
618
+ "double dashed": "(p (move 1d 0a)) (move 1d 0a) (p (move 1d 0a)) (move 1d 0a)",
619
+ "Greek": "(loop i 3 (move (*l 1l i) (/a 1a 4)))"}
620
+ for name in body:
621
+ if name == "spiral" and n not in [3,5]: continue
622
+ if name == "square" and n not in [5,3,6,7]: continue
623
+ if name == "semicircle" and n not in [5,3,4,6]: continue
624
+ if name == "Greek" and n not in [3,5]: continue
625
+ if name == "double dashed" and n not in [6,4,3]: continue
626
+
627
+ mustTrain = False
628
+
629
+ mustTrain = mustTrain or (n == 3 and name == "Greek")
630
+ mustTrain = mustTrain or (n == 7 and name == "empty")
631
+ mustTrain = mustTrain or (n == 5 and name == "dashed")
632
+ mustTrain = mustTrain or (n == 7 and name == "circle")
633
+ mustTrain = mustTrain or (n == 6 and name == "circle")
634
+ mustTrain = mustTrain or (n == 6 and name == "lonely circle")
635
+ mustTrain = mustTrain or (n == 5 and name == "square")
636
+ mustTrain = mustTrain or (n == 7 and name == "square")
637
+ mustTrain = mustTrain or (n == 5 and name == "semicircle")
638
+ mustTrain = mustTrain or (n == 3 and name == "square dashed")
639
+ mustTrain = mustTrain or (n == 6 and name == "close semicircle")
640
+ mustTrain = mustTrain or (n == 5 and name == "close large semicircle")
641
+ mustTrain = mustTrain or (n == 3 and name == "spiral")
642
+ mustTrain = mustTrain or (n == 6 and name == "double dashed")
643
+ mustTrain = mustTrain or (n == 3 and name == "double dashed")
644
+ #mustTrain = mustTrain or (n == 6 and name == "empty")
645
+
646
+ #mustTrain = mustTrain or (random.random() < 0.07) # calibrated to give 70 training tasks
647
+
648
+
649
+ # # cap number of super easy snowflakes
650
+ # if name == "empty" and n not in [7]: mustTrain = False
651
+ # if name == "dashed" and n not in [4]: mustTrain = False
652
+
653
+
654
+ T("%d-%s snowflake"%(n,name),
655
+ """
656
+ (loop j %d
657
+ (embed %s)
658
+ (move 0d (/a 1a %d)))"""%(n,body[name],n),
659
+ needToTrain=mustTrain)
660
+
661
+ for n in [3,4]:#2,3,4]:
662
+ T("%d-row of squares"%n,
663
+ """
664
+ (loop i %d
665
+ (embed (loop k 4 (move 1d (/a 1a 4))))
666
+ (move 1d 0a))
667
+ """%n,
668
+ needToTrain=n == 4)
669
+ T("2x2 grid",
670
+ """
671
+ (for x 2 (embed (for y 2
672
+ (embed (loop k 4 (move 1d (/a 1a 4))))
673
+ (move 1d 0a)))
674
+ (move 0d (/a 1a 4)) (move 1d (-a 0a (/a 1a 4))))
675
+ """)
676
+ T("slanted squares",
677
+ """
678
+ ((embed (loop k 4 (move 1d (/a 1a 4))))
679
+ (move 0d (/a 1a 8))
680
+ (loop k 4 (move 1d (/a 1a 4))))
681
+ """)
682
+ for l in range(1,6):
683
+ T("square of size %d"%l,
684
+ """
685
+ (for i 4
686
+ (move (*d 1d %d) (/a 1a 4)))
687
+ """%l,
688
+ needToTrain=l in range(4))
689
+ for n in [5,7]:
690
+ T("%d-concentric squares"%n,
691
+ """
692
+ (for i %d
693
+ (embed (loop j 4 (move (*d 1d i) (/a 1a 4)))))
694
+ """%n,
695
+ needToTrain=n == 5)
696
+ return tasks
697
+
698
+ def montageTasks(tasks, prefix="", columns=None, testTrain=False):
699
+ import numpy as np
700
+
701
+ w = 128
702
+ arrays = [t.highresolution for t in tasks]
703
+ for a in arrays:
704
+ assert len(a) == w*w
705
+
706
+ if testTrain:
707
+ arrays = [a for a,t in zip(arrays, tasks) if t.mustTrain ] + [a for a,t in zip(arrays, tasks) if not t.mustTrain ]
708
+
709
+ arrays = [np.array([a[i:i + w]
710
+ for i in range(0, len(a), w) ])
711
+ for a in arrays]
712
+ i = montage(arrays, columns=columns)
713
+
714
+ import scipy.misc
715
+ scipy.misc.imsave('/tmp/%smontage.png'%prefix, i)
716
+ if testTrain:
717
+ trainingTasks = arrays[:sum(t.mustTrain for t in tasks)]
718
+ testingTasks = arrays[sum(t.mustTrain for t in tasks):]
719
+ random.shuffle(trainingTasks)
720
+ random.shuffle(testingTasks)
721
+ arrays = trainingTasks + testingTasks
722
+ else:
723
+ random.shuffle(arrays)
724
+ scipy.misc.imsave('/tmp/%srandomMontage.png'%prefix, montage(arrays, columns=columns))
725
+
726
+ def demoLogoTasks():
727
+ import scipy.misc
728
+ import numpy as np
729
+
730
+ g0 = Grammar.uniform(primitives, continuationType=turtle)
731
+ eprint("dreaming into /tmp/dreams_0...")
732
+ N = 1000
733
+ programs = [ p
734
+ for _ in range(N)
735
+ for p in [g0.sample(arrow(turtle,turtle),
736
+ maximumDepth=20)]
737
+ if p is not None]
738
+ os.system("mkdir -p /tmp/dreams_0")
739
+ for n,p in enumerate(programs):
740
+ with open(f"/tmp/dreams_0/{n}.dream","w") as handle:
741
+ handle.write(str(p))
742
+ drawLogo(*programs, pretty=True, smoothPretty=False,
743
+ resolution=512,
744
+ filenames=[f"/tmp/dreams_0/{n}_pretty.png"
745
+ for n in range(len(programs)) ],
746
+ timeout=1)
747
+
748
+ if len(sys.argv) > 1:
749
+ tasks = makeTasks(sys.argv[1:],proto=False)
750
+ else:
751
+ tasks = makeTasks(['all'],proto=False)
752
+ montageTasks(tasks,columns=16,testTrain=True)
753
+ for n,t in enumerate(tasks):
754
+ a = t.highresolution
755
+ w = int(len(a)**0.5)
756
+ scipy.misc.imsave('/tmp/logo%d.png'%n, np.array([a[i:i+w]
757
+ for i in range(0,len(a),w) ]))
758
+ logo_safe_name = t.name.replace("=","_").replace(' ','_').replace('/','_').replace("-","_") + ".png"
759
+ #os.system(f"convert /tmp/logo{n}.png -morphology Dilate Octagon /tmp/{logo_safe_name}")
760
+ os.system(f"convert /tmp/logo{n}.png -channel RGB -negate /tmp/{logo_safe_name}")
761
+ eprint(len(tasks),"tasks")
762
+ eprint(sum(t.mustTrain for t in tasks),"need to be trained on")
763
+
764
+ for t in dSLDemo():
765
+ a = t.highresolution
766
+ w = int(len(a)**0.5)
767
+ scipy.misc.imsave('/tmp/logoDemo%s.png'%t.name, np.array([a[i:i+w]
768
+ for i in range(0,len(a),w) ]))
769
+ os.system(f"convert /tmp/logoDemo{t.name}.png -morphology Dilate Octagon /tmp/logoDemo{t.name}_dilated.png")
770
+
771
+ tasks = [t for t in tasks if t.mustTrain ]
772
+ random.shuffle(tasks)
773
+ montageTasks(tasks[:16*3],"subset",columns=16)
774
+
775
+ montageTasks(rotationalSymmetryDemo(),"rotational")
776
+
777
+
dreamcoder/domains/misc/RobustFillPrimitives.py ADDED
@@ -0,0 +1,308 @@
1
+ #RobustFillPrimitives
2
+
3
+ from dreamcoder.program import Primitive, prettyProgram
4
+ from dreamcoder.grammar import Grammar
5
+ from dreamcoder.type import tint, arrow, baseType #, t0, t1, t2
6
+
7
+ from string import printable
8
+ import re
9
+ from collections import defaultdict
10
+
11
+ #from functools import reduce
12
+
13
+
14
+ disallowed = [
15
+ ("#", "hash"),
16
+ ("!", "bang"),
17
+ ("\"", "double_quote"),
18
+ ("$", "dollar"),
19
+ ("%", "percent"),
20
+ ("&", "ampersand"),
21
+ ("'", "single_quote"),
22
+ (")", "left_paren"),
23
+ ("(", "right_paren"),
24
+ ("*", "astrisk"),
25
+ ("+", "plus"),
26
+ (",", "comma"),
27
+ ("-", "dash"),
28
+ (".", "period"),
29
+ ("/", "slash"),
30
+ (":", "colon"),
31
+ (";", "semicolon"),
32
+ ("<", "less_than"),
33
+ ("=", "equal"),
34
+ (">", "greater_than"),
35
+ ("?", "question_mark"),
36
+ ("@", "at"),
37
+ ("[", "left_bracket"),
38
+ ("\\", "backslash"),
39
+ ("]", "right_bracket"),
40
+ ("^", "carrot"),
41
+ ("_", "underscore"),
42
+ ("`", "backtick"),
43
+ ("|", "bar"),
44
+ ("}", "right_brace"),
45
+ ("{", "left_brace"),
46
+ ("~", "tilde"),
47
+ (" ", "space"),
48
+ ("\t", "tab")
49
+ ]
50
+ disallowed = dict(disallowed)
51
+ delimiters = "&,.?!@()[]%{/}:;$#\"'"
52
+
53
+ delim_dict = {disallowed[c]:c for c in delimiters}
54
+
55
+ types = {}
56
+ types["Number"] = r'\d+'
57
+ types["Word"] = r'\w+'
58
+ types["Alphanum"] = r'\w'
59
+ types["PropCase"] = r'[A-Z][a-z]+'
60
+ types["AllCaps"] = r'[A-Z]'
61
+ types["Lower"] = r'[a-z]'
62
+ types["Digit"] = r'\d'
63
+ types["Char"] = r'.'
64
+
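+ # e.g. re.findall(types["PropCase"], "Hello world FOO Bar") == ['Hello', 'Bar']
+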
65
+ regexes = {name: re.escape(val) for name, val in delim_dict.items()}
66
+ regexes = {**regexes, **types}
67
+
68
+ tposition = baseType("position")
69
+ tindex = baseType("index")
70
+ tcharacter = baseType("character")
71
+ tboundary = baseType("boundary")
72
+ tregex = baseType("regex")
73
+ tsubstr = baseType("substr")
74
+ texpression = baseType("expression")
75
+ tprogram = baseType("program")
76
+ tnesting = baseType("nesting")
77
+ ttype = baseType("type")
78
+ tdelimiter = baseType("delimiter")
79
+
80
+ def _substr(k1): return lambda k2: lambda string: string[k1:k2] #i think this is fine
81
+ def _getspan(r1):
+     # slice from the i1-th match of regex r1 to the i2-th match of regex r2, where each
+     # endpoint is the match's end if its boundary flag is "End" and its start otherwise
+     return lambda i1: lambda b1: lambda r2: lambda i2: lambda b2: lambda string: \
+         string[([m.end() for m in re.finditer(r1, string)][i1] if b1 == "End"
+                 else [m.start() for m in re.finditer(r1, string)][i1]):
+                ([m.end() for m in re.finditer(r2, string)][i2] if b2 == "End"
+                 else [m.start() for m in re.finditer(r2, string)][i2])]
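+ # worked example: _getspan(r'\d+')(0)("End")(r'\d+')(1)("Start")("ab12cd34") == "cd"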
87
+ def _getspan_const(r1): return lambda i1: lambda b1: lambda r2: lambda i2: lambda b2: (defaultdict(int, {r1:i1+1 if i1>=0 else abs(i1), r2:i2+1 if i2>=0 else abs(i2)}), max(i1+1 if i1>=0 else abs(i1), i2+1 if i2>=0 else abs(i2)))
88
+
89
+
90
+ def _trim(string):
91
+ assert False, "_trim is not implemented yet (see the todo list below)"
92
+ return string
93
+
94
+ def _replace(d1, d2): return lambda string: string.replace(d1,d2)
95
+
96
+ def _getall(tp): return lambda string: ''.join(re.findall(tp, string))
97
+ def _getfirst(tp, i): return lambda string: ''.join(re.findall(tp, string)[:i])
98
+ def _gettoken(tp, i): return lambda string: re.findall(tp, string)[i]
99
+ def _gettoken_const(tp, i): return defaultdict(int, {tp: i+1 if i>=0 else abs(i)}), i+1 if i>=0 else abs(i)
100
+
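+ # the *_const companions apparently encode constraints as (regex -> required match count, minimum match count)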
101
+ def _getupto(reg): return lambda string: string[:[m.end() for m in re.finditer(reg, string)][0]]
102
+ def _getfrom(reg): return lambda string: string[[m.end() for m in re.finditer(reg, string)][-1]:]
103
+
104
+ def _concat2(expr1): return lambda expr2: lambda string: expr1(string) + expr2(string) #More concats plz
105
+ def _concat1(expr): return lambda string: expr(string)
106
+ def _concat_list(expr): return lambda program: lambda string: expr(string) + program(string)
107
+ # I've decided that everything which is an expression should take tstring as its last input and output a tstring. Thus, all requests are arrow(tstring, tstring), and we limit size with recursion depth.
108
+ """
109
+ todo:
110
+ - _trim
111
+ - incorporate tcharacter
112
+ - constraints
113
+ - format _getspan
114
+ - figure out how to represent on top_level
115
+
116
+ - flatten for nn
117
+ - parse
118
+
119
+ - robustfill_util
120
+ - train dc model for robustfill
121
+ - main_supervised_robustfill
122
+ - evaluate_robustfill
123
+ - sample_data
124
+
125
+
126
+ - deal with escapes ...
127
+
128
+ constraints:
129
+ elements, and number necessary, and lengths
130
+ """
131
+
132
+ def robustFillPrimitives(max_len=100, max_index=5):
133
+ return [
134
+ #CPrimitive("concat2", arrow(texpression, texpression, tprogram), _concat2),
135
+ CPrimitive("concat1", arrow(texpression, tprogram), _concat1),
136
+ CPrimitive("concat_list", arrow(texpression, tprogram, tprogram), _concat_list),
137
+ #expressions
138
+ CPrimitive("Constant", arrow(tcharacter, texpression), lambda x: lambda y: x), # add a constraint
139
+ CPrimitive("apply", arrow(tnesting, tsubstr, texpression), lambda n: lambda sub: lambda string: n(sub(string))),
140
+ CPrimitive("apply_n", arrow(tnesting, tnesting, texpression), lambda n1: lambda n2: lambda string: n1(n2(string))),
141
+ CPrimitive("expr_n", arrow(tnesting, texpression), lambda x: x),
142
+ CPrimitive("expr_f", arrow(tsubstr, texpression), lambda x: x)
143
+ ] + [
144
+ #substrings
145
+ CPrimitive("SubStr", arrow(tposition, tposition, tsubstr), _substr), # handled
146
+ CPrimitive("GetSpan", arrow(tregex, tindex, tboundary, tregex, tindex, tboundary, tsubstr), _getspan, _getspan_const) #TODO constraint
147
+ ] + [
148
+ #nestings
149
+ CPrimitive("GetToken"+name+str(i), tnesting, _gettoken(tp,i), _gettoken_const(tp, i)) for name, tp in types.items() for i in range(-max_index, max_index)
150
+ ] + [
151
+ CPrimitive("ToCase_ProperCase", tnesting, lambda x: x.title(), (defaultdict(int, {r'[A-Z][a-z]+':1}), 1)),
152
+ CPrimitive("ToCase_AllCapsCase", tnesting, lambda x: x.upper(), (defaultdict(int, {r'[A-Z]':1}) ,1)),
153
+ CPrimitive("ToCase_LowerCase", tnesting, lambda x: x.lower(), (defaultdict(int, {r'[a-z]':1}), 1) )
154
+ ] + [
155
+ CPrimitive("Replace_"+name1+name2, tnesting, _replace(char1, char2), (defaultdict(int, {char1:1}), 1)) for name1, char1 in delim_dict.items() for name2, char2 in delim_dict.items() if char1 is not char2
156
+ ] + [
157
+ #CPrimitive("Trim", tnesting, _trim), #TODO
158
+ ] + [
159
+ CPrimitive("GetUpTo"+name, tnesting, _getupto(reg), (defaultdict(int, {reg:1} ),1)) for name, reg in regexes.items()
160
+ ] + [
161
+ CPrimitive("GetFrom"+name, tnesting, _getfrom(reg), (defaultdict(int, {reg:1} ),1)) for name, reg in regexes.items()
162
+ ] + [
163
+ CPrimitive("GetFirst_"+name+str(i), tnesting, _getfirst(tp, i), (defaultdict(int, {tp:i} ), i+1 if i>=0 else abs(i))) for name, tp in types.items() for i in list(range(-max_index,0))+ list(range(1,max_index+1))
164
+ ] + [
165
+ CPrimitive("GetAll_"+name, tnesting, _getall(reg),(defaultdict(int, {reg:1} ),1) ) for name, reg in types.items()
166
+ ] + [
167
+ #regexes
168
+ CPrimitive("type_to_regex", arrow(ttype, tregex), lambda x: x), #TODO also make disappear
169
+ CPrimitive("delimiter_to_regex", arrow(tdelimiter, tregex), lambda x: re.escape(x)) #TODO also make disappear
170
+ ] + [
171
+ #types
172
+ CPrimitive("Number", ttype, r'\d+', r'\d+'), #TODO
173
+ CPrimitive("Word", ttype, r'\w+', r'\w+'), #TODO
174
+ CPrimitive("Alphanum", ttype, r'\w', r'\w'), #TODO
175
+ CPrimitive("PropCase", ttype, r'[A-Z][a-z]+', r'[A-Z][a-z]+'), #TODO
176
+ CPrimitive("AllCaps", ttype, r'[A-Z]', r'[A-Z]'), #TODO
177
+ CPrimitive("Lower", ttype, r'[a-z]', r'[a-z]'), #TODO
178
+ CPrimitive("Digit", ttype, r'\d', r'\d'), #TODO
179
+ CPrimitive("Char", ttype, r'.', r'.') #TODO
180
+ ] + [
181
+ #Cases
182
+ # CPrimitive("ProperCase", tcase, .title()), #TODO
183
+ # CPrimitive("AllCapsCase", tcase, .upper()), #TODO
184
+ # CPrimitive("LowerCase", tcase, .lower()) #TODO
185
+ ] + [
186
+ #positions
187
+ CPrimitive("position"+str(i), tposition, i, (defaultdict(int), i+1 if i>=0 else abs(i)) ) for i in range(-max_len,max_len+1) #deal with indicies
188
+ ] + [
189
+ #indices
190
+ CPrimitive("index"+str(i), tindex, i, i) for i in range(-max_index,max_index+1) #deal with indicies
191
+ ] + [
192
+ #characters
193
+ CPrimitive(i, tcharacter, i, (defaultdict(int, {i:1}),1) ) for i in printable[:-5] if i not in disallowed
194
+ ] + [
195
+ CPrimitive(name, tcharacter, char, (defaultdict(int, {char:1}), 1)) for char, name in disallowed.items() # NB: disallowed is reversed
196
+ ] + [
197
+ #delimiters
198
+ CPrimitive("delim_"+name, tdelimiter, char, char) for name, char in delim_dict.items()
199
+ ] + [
200
+ #boundaries
201
+ CPrimitive("End", tboundary, "End"),
202
+ CPrimitive("Start", tboundary, "Start")
203
+ ]
204
+
205
+
206
+
207
+ def RobustFillProductions(max_len=100, max_index=5):
208
+ return [(0.0, prim) for prim in robustFillPrimitives(max_len=max_len, max_index=max_index)]
209
+
210
+
211
+ def flatten_program(p):
212
+ string = p.show(False)
213
+ string = string.replace('(', '')
214
+ string = string.replace(')', '')
215
+ #remove '_fn' (optional)
216
+ string = string.split(' ')
217
+ string = list(filter(lambda x: x != '', string))  # "is not ''" tested identity, not equality
218
+ return string
219
+
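+ # sketch: a program printing as "(concat1 (expr_n ToCase_LowerCase))" flattens to
+ # ['concat1', 'expr_n', 'ToCase_LowerCase']: parentheses dropped, empty tokens filtered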
220
+
221
+
222
+
223
+ def add_constraints(c1, c2=None):
224
+ if c2 is None:
225
+ return c1
226
+ d1, m1 = c1
227
+ d2, m2 = c2
228
+ min_size = max(m1, m2)
229
+ d = defaultdict(int)
230
+ for item in set(d1.keys()) | set(d2.keys()):
231
+ d[item] = max(d1[item], d2[item])
232
+ return d, min_size
233
+
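+ # worked example of the merge above (elementwise max of counts, max of minimum sizes):
+ #   add_constraints((defaultdict(int, {'a': 2}), 3), (defaultdict(int, {'b': 1}), 5))
+ #   == (defaultdict(int, {'a': 2, 'b': 1}), 5)
+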
234
+ # class Constraint_prop:
235
+ # def application(self, p, environment):
236
+ # self.f.visit(self, environment)(self.x.visit(self, environment))
237
+ # def primitive(self, p, environment):
238
+ # return self.value
239
+
240
+ class Constraint_prop:
241
+ def __init__(self):
242
+ pass
243
+
244
+ def application(self, p):
245
+ return p.f.visit(self)(p.x.visit(self))
246
+
247
+ def primitive(self, p):
248
+ return p.constraint
249
+
250
+ def execute(self, p):
251
+ return p.visit(self)
252
+
253
+
254
+ class CPrimitive(Primitive):
255
+ def __init__(self, name, ty, value, constraint=None):
256
+ #I have no idea why this works but it does .....
257
+ if constraint is None:
258
+ if len(ty.functionArguments())==0:
259
+ self.constraint = (defaultdict(int), 0)
260
+ elif len(ty.functionArguments())==1:
261
+ self.constraint = lambda x: x
262
+ elif len(ty.functionArguments())==2:
263
+ self.constraint = lambda x: lambda y: add_constraints(x,y)
264
+ else:
265
+ self.constraint = lambda x: x
266
+ for _ in range(len(ty.functionArguments()) - 1):
267
+ self.constraint = lambda x: lambda y: add_constraints(x, self.constraint(y))
268
+ else: self.constraint = constraint
269
+ super(CPrimitive, self).__init__(name, ty, value)
270
+
271
+ #def __getinitargs__(self):
272
+ # return (self.name, self.tp, self.value, None)
273
+
274
+ def __getstate__(self):
275
+ #print("self.name", self.name)
276
+ return self.name
277
+
278
+ def __setstate__(self, state):
279
+ #for backwards compatibility:
280
+ if type(state) == dict:
281
+ pass #do nothing, i don't need to load them if they are old...
282
+ else:
283
+ p = Primitive.GLOBALS[state]
284
+ self.__init__(p.name, p.tp, p.value, p.constraint)
285
+
286
+
287
+
288
+ if __name__=='__main__':
289
+ import time
290
+ CPrimitive("testCPrim", tint, lambda x: x, 17)
291
+ g = Grammar.fromProductions(RobustFillProductions())
292
+ print(len(g))
293
+ request = tprogram
294
+ p = g.sample(request)
295
+ print("request:", request)
296
+ print("program:")
297
+ print(prettyProgram(p))
298
+ s = 'abcdefg'
299
+ e = p.evaluate([])
300
+ #print("prog applied to", s)
301
+ #print(e(s))
302
+ print("flattened_program:")
303
+ flat = flatten_program(p)
304
+ print(flat)
305
+ t = time.time()
306
+ constraints = Constraint_prop().execute(p)
307
+ print(time.time() - t)
308
+ print(constraints)
dreamcoder/domains/misc/__init__.py ADDED
File without changes
dreamcoder/domains/misc/algolispPrimitives.py ADDED
@@ -0,0 +1,508 @@
1
+ #algolispPrimitives.py
2
+ from dreamcoder.program import Primitive, Program
3
+ from dreamcoder.grammar import Grammar
4
+ from dreamcoder.type import tlist, arrow, baseType #, t0, t1, t2
5
+
6
+ #from functools import reduce
7
+
8
+
9
+ #Internal TYPES:
10
+ # NUMBER
11
+ # BOOLEAN
12
+ # NOTFUNCTYPE
13
+ # Type
14
+ # ANYTYPE
15
+
16
+ #types
17
+ tsymbol = baseType("symbol")
18
+ #PROGRAM = SYMBOL = constant | argument | function_call | function | lambda
19
+ tconstant = baseType("constant")
20
+ tfunction = baseType("function")
21
+
22
+ f = dict([("|||","triple_or"),
23
+ ("reduce","reduce"),
24
+ ("+","+"),
25
+ ("len","len"),
26
+ ("map","map"),
27
+ ("filter","filter"),
28
+ ("-","-"),
29
+ ("*","*"),
30
+ ("partial0","partial0"),
31
+ ("if","if"),
32
+ ("lambda1","lambda1"),
33
+ ("==","eq"),
34
+ ("range","range"),
35
+ ("digits","digits"),
36
+ ("slice","slice"),
37
+ ("reverse","reverse"),
38
+ ("lambda2","lambda2"),
39
+ ("deref","deref"),
40
+ ("partial1","partial1"),
41
+ ("/","div"),
42
+ ("<","less_than"),
43
+ (">","greater_than"),
44
+ ("min","min"),
45
+ ("combine","combine"),
46
+ ("head","head"),
47
+ ("is_prime","is_prime"),
48
+ ("false","false"),
49
+ ("||","or"),
50
+ ("10","10"),
51
+ ("self","self"),
52
+ ("max","max"),
53
+ ("sort","sort"),
54
+ ("%","mod"),
55
+ ("invoke1","invoke1"),
56
+ ("!","bang"),
57
+ ("square","square"),
58
+ ("str_concat","str_concat"),
59
+ ("strlen","strlen"),
60
+ ("<=","leq"),
61
+ ("int-deref","int-deref"),
62
+ ("str_split","str_split"),
63
+ ("str_index","str_index"),
64
+ ("floor","floor"),
65
+ ("sqrt","sqrt"),
66
+ ("str_min","str_min"),
67
+ ("&&","AND"),
68
+ ("is_sorted","is_sorted"),
69
+ ("str_max","str_max"),
70
+ (">=","geq")])
71
+
72
+ fn_lookup = {
73
+ **f
74
+ }
75
+
76
+ c = dict(
77
+ [("0","0"),
78
+ ("a","a"),
79
+ ("arg1","arg1"),
80
+ ("1","1"),
81
+ ("b","b"),
82
+ ("2","2"),
83
+ ("c","c"),
84
+ ("arg2","arg2"),
85
+ ("d","d"),
86
+ ("false","false"),
87
+ ("10","10"),
88
+ ("self","self"),
89
+ ("1000000000","1000000000"),
90
+ ("\"\"", "empty_str"),
91
+ ("e","e"),
92
+ ("40","40"),
93
+ ("f","f"),
94
+ ("\" \"", "space"),
95
+ ("g","g"),
96
+ ("\"z\"","z"),
97
+ ("true","true"),
98
+ ("h","h"),
99
+ ("i","i"),
100
+ ("j","j"),
101
+ ("k","k"),
102
+ ("l","l")]
103
+ )
104
+
105
+ const_lookup = {
106
+ **c
107
+ }
108
+
109
+ primitive_lookup = {**const_lookup, **fn_lookup}
110
+ #Do i need arguments??
111
+
112
+ def _fn_call(f):
113
+ #print("f", f)
114
+ def inner(sx):
115
+ #print("sx", sx)
116
+ if not type(sx) == list:
117
+ sx = [sx]
118
+ return [f] + sx
119
+ return lambda sx: inner(sx)
120
+
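+ # worked examples: _fn_call("reduce")(["a", "b"]) == ["reduce", "a", "b"];
+ # a lone symbol is wrapped first, so _fn_call("len")("a") == ["len", "a"]
+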
+ def algolispPrimitives():
+     return [
+         Primitive("fn_call", arrow(tfunction, tlist(tsymbol), tsymbol), _fn_call),
+ 
+         Primitive("lambda1_call", arrow(tfunction, tlist(tsymbol), tsymbol), lambda f: lambda sx: ["lambda1", [f] + sx] if isinstance(sx, list) else ["lambda1", [f] + [sx]]),
+         Primitive("lambda2_call", arrow(tfunction, tlist(tsymbol), tsymbol), lambda f: lambda sx: ["lambda2", [f] + sx] if isinstance(sx, list) else ["lambda2", [f] + [sx]]),
+         #symbol converters:
+         # SYMBOL = constant | argument | function_call | function | lambda
+         Primitive("symbol_constant", arrow(tconstant, tsymbol), lambda x: x),
+         Primitive("symbol_function", arrow(tfunction, tsymbol), lambda x: x),
+         #list converters
+         Primitive('list_init_symbol', arrow(tsymbol, tlist(tsymbol)), lambda symbol: [symbol]),
+         Primitive('list_add_symbol', arrow(tsymbol, tlist(tsymbol), tlist(tsymbol)), lambda symbol: lambda symbols: symbols + [symbol] if isinstance(symbols, list) else [symbols] + [symbol])
+     ] + [
+         #functions:
+         Primitive(ec_name, tfunction, algo_name) for algo_name, ec_name in fn_lookup.items()
+     ] + [
+         #Constants
+         Primitive(ec_name, tconstant, algo_name) for algo_name, ec_name in const_lookup.items()
+     ]
+ 
+ 
+ #for first pass, can just hard code vars and maps n stuff
+ 
+ 
+ def algolispProductions():
+     return [(0.0, prim) for prim in algolispPrimitives()]
+ 
+ 
+ algolisp_input_vocab = [
+     "<S>", "</S>", "<UNK>", "|||", "(", ")", "a", "b", "of", "the",
+     "0", ",", "arg1", "c", "and", "1", "reduce", "+", "int[]", "in",
+     "given", "numbers", "int", "is", "len", "map", "digits", "d", "number", "array",
+     "-", "filter", "to", "range", "are", "*", "partial0", "2", "if", "reverse",
+     "that", "elements", "lambda1", "==", "an", "arg2", "values", "slice", "element", "lambda2",
+     "deref", "you", "partial1", "e", "find", "your", "task", "compute", "among", "from",
+     "consider", "first", "than", "value", "/", "what", "arrays", "with", "<", "length",
+     ">", "be", "min", "end", "sum", "one", "head", "f", "by", "combine",
+     "segment", "coordinates", "not", "string", "is_prime", "false", "||", "at", "10", "half",
+     "position", "self", "subsequence", "after", "such", "max", "prime", "sort", "let", "%",
+     "longest", "inclusive", "which", "invoke1", "1000000000", "all", "positions", "!", "square", "its",
+     "has", "reversed", "another", "less", "each", "\"\"", "order", "largest", "maximum", "g",
+     "last", "smallest", "times", "strictly", "40", "smaller", "indexes", "str_concat", "strlen", "two",
+     "starting", "<=", "on", "greater", "how", "many", "int-deref", "prefix", "bigger", "only",
+     "str_split", "\" \"", "str_index", "can", "plus", "squared", "product", "strings", "floor", "sqrt",
+     "before", "it", "concatenation", "index", "as", "define", "multiplied", "biggest", "rounded", "down",
+     "string[]", "equal", "integer", "also", "based", "sorting", "replace", "becomes", "single", "digit",
+     "characters", "keeping", "including", "h", "larger", "written", "divisible", "previous", "subarray", "mininum",
+     "second", "middle", "same", "th", "median", "till", "integers", "sequence", "for", "indices",
+     "between", "when", "doubled", "ending", "even", "multiply", "squares", "fibonacci", "exclusive", "odd",
+     "keep", "whether", "minimum", "except", "letters", "appearing", "letter", "consecutive", "character", "factorial",
+     "chosen", "start", "begin", "themselves", "\"z\"", "str_min", "remove", "present", "exist", "appear",
+     "starts", "i", "located", "true", "&&", "found", "discarding", "is_sorted", "removing", "do",
+     "increasing", "exceed", "ascending", "difference", "decremented", "existing", "alphabetically", "words", "added", "incremented",
+     "backwards", "individual", "lexicographically", "separate", "abbreviation", "str_max", "increment", "consisting", "equals", "having",
+     "discard", "descending", "decreasing", "sorted", "being", "where", "right", "there", "ordinal", "have",
+     "s", "going", "'", "add", "space", "decrement", "those", "whitespaces", "spaces", "subtract",
+     "remaining", "following", "or", "out", "ordered", "minimal", "itself", "symmetric", "read", "increases",
+     "word", "immidiately", "excluding", "j", "omitting", "reads", "maximal", ">=", "compare", "form",
+     "absent", "missing", "cannot", "whose", "count", "lowest", "both", "ends", "beginning", "left",
+     "mean", "average", "obtained", "writing", "result", "joining", "together", "increase", "highest", "comparing",
+     "forms", "avg", "outside", "positive", "summed", "belonging", "lexicographical", "rest", "belong", "inclucing",
+     "lexical", "alphabetical", "dictionary", "k", "negative", "lexicographic", "represents", "delete", "non", "l",
+     "erase", "m", "comes", "up", "comparison", "during",
+     "'s value is the largest inclusive, which is strictly less than maximum element in numbers from 1 to the element in `a` which'",
+     "'s value is the biggest (inclusive), which is strictly less than maximum element of range from 1 to the element in `a` which'",
+     "'s value is the highest, which is strictly less than maximum element among sequence of digits of the element in `a` which'"]
+ 
+ 
+ if __name__ == "__main__":
+     #g = Grammar.uniform(deepcoderPrimitives())
+ 
+     g = Grammar.fromProductions(algolispProductions(), logVariable=.9)
+ 
+     #p = Program.parse("(lambda (fn_call filter (list_add_symbol (lambda1_call == (list_add_symbol 1 (list_init_symbol (fn_call mod ( list_add_symbol 2 (list_init_symbol arg1)) ))) ) (list_init_symbol $0)) )")
+     p = Program.parse("(lambda (fn_call filter (list_add_symbol (lambda1_call eq (list_add_symbol (symbol_constant 1) (list_init_symbol (fn_call mod ( list_add_symbol (symbol_constant 2) (list_init_symbol (symbol_constant arg1))) ))) ) (list_init_symbol (symbol_constant $0)))))")
+ 
+     print(p)
+ 
+     #tree = p.evaluate(["a"])
+     tree = p.evaluate([])
+     print(tree("a"))
+ 
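+     # Hand-traced expectation (added comment; derived from the definitions
+     # above rather than from running the original script):
+     #   tree("a") == ["filter", "a", ["lambda1", ["eq", ["mod", "arg1", "2"], "1"]]]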
+ #
dreamcoder/domains/misc/deepcoderPrimitives.py ADDED
@@ -0,0 +1,352 @@
+ from dreamcoder.program import Primitive, prettyProgram
+ from dreamcoder.grammar import Grammar
+ from dreamcoder.type import tlist, tint, arrow, baseType #, t0, t1, t2
+ 
+ #from functools import reduce
+ 
+ 
+ #todo
+ int_to_int = baseType("int_to_int")
+ int_to_bool = baseType("int_to_bool")
+ int_to_int_to_int = baseType("int_to_int_to_int")
+ 
+ 
+ #deepcoderPrimitives
+ Null = 300 #or perhaps something else, like "an integer outside the working range"?
+ 
+ def _head(xs): return xs[0] if len(xs) > 0 else Null
+ def _tail(xs): return xs[-1] if len(xs) > 0 else Null
+ def _take(n): return lambda xs: xs[:n]
+ def _drop(n): return lambda xs: xs[n:]
+ def _access(n): return lambda xs: xs[n] if n >= 0 and len(xs) > n else Null
+ def _minimum(xs): return min(xs) if len(xs) > 0 else Null
+ def _maximum(xs): return max(xs) if len(xs) > 0 else Null
+ def _reverse(xs): return list(reversed(xs))
+ def _sort(xs): return sorted(xs)
+ def _sum(xs): return sum(xs)
+ 
+ #higher order:
+ def _map(f): return lambda l: list(map(f, l))
+ def _filter(f): return lambda l: list(filter(f, l))
+ def _count(f): return lambda l: len([x for x in l if f(x)])
+ def _zipwith(f): return lambda xs: lambda ys: [f(x)(y) for (x, y) in zip(xs, ys)]
+ def _scanl1(f):
+     def _inner(xs):
+         ys = []
+         if len(xs) > 0:
+             ys.append(xs[0])
+             for i in range(1, len(xs)):
+                 ys.append(f(ys[i-1])(xs[i]))
+         return ys
+     return _inner
+ 
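+ # Worked example (added comment): _scanl1 is a left scan that keeps every
+ # partial result, e.g. _scanl1(lambda a: lambda b: a + b)([1, 2, 3]) == [1, 3, 6],
+ # and _scanl1(f)([]) == [] for any f.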
+ #int to int:
+ def _succ(x): return x + 1
+ def _pred(x): return x - 1
+ def _double(x): return x * 2
+ def _half(x): return int(x / 2)
+ def _negate(x): return -x
+ def _square(x): return x**2
+ def _triple(x): return x * 3
+ def _third(x): return int(x / 3)
+ def _quad(x): return x * 4
+ def _quarter(x): return int(x / 4)
+ 
+ #int to bool:
+ def _pos(x): return x > 0
+ def _neg(x): return x < 0
+ def _even(x): return x % 2 == 0
+ def _odd(x): return x % 2 == 1
+ 
+ #int to int to int:
+ def _add(x): return lambda y: x + y
+ def _sub(x): return lambda y: x - y
+ def _mult(x): return lambda y: x * y
+ def _min(x): return lambda y: _minimum([x, y])
+ def _max(x): return lambda y: _maximum([x, y])
+ 
+ def deepcoderPrimitives():
+     return [
+         Primitive("HEAD", arrow(tlist(tint), tint), _head),
+         Primitive("LAST", arrow(tlist(tint), tint), _tail),
+         Primitive("TAKE", arrow(tint, tlist(tint), tlist(tint)), _take),
+         Primitive("DROP", arrow(tint, tlist(tint), tlist(tint)), _drop),
+         Primitive("ACCESS", arrow(tint, tlist(tint), tint), _access),
+         Primitive("MINIMUM", arrow(tlist(tint), tint), _minimum),
+         Primitive("MAXIMUM", arrow(tlist(tint), tint), _maximum),
+         Primitive("REVERSE", arrow(tlist(tint), tlist(tint)), _reverse),
+         Primitive("SORT", arrow(tlist(tint), tlist(tint)), _sort),
+         Primitive("SUM", arrow(tlist(tint), tint), _sum)
+     ] + [
+         Primitive("MAP", arrow(int_to_int, tlist(tint), tlist(tint)), _map), #is this okay???
+         Primitive("FILTER", arrow(int_to_bool, tlist(tint), tlist(tint)), _filter), #is this okay???
+         Primitive("COUNT", arrow(int_to_bool, tlist(tint), tint), _count), #is this okay???
+         Primitive("ZIPWITH", arrow(int_to_int_to_int, tlist(tint), tlist(tint), tlist(tint)), _zipwith), #is this okay???
+         Primitive("SCANL1", arrow(int_to_int_to_int, tlist(tint), tlist(tint)), _scanl1), #is this okay???
+     ] + [
+         Primitive("INC", int_to_int, _succ),
+         Primitive("DEC", int_to_int, _pred),
+         Primitive("SHL", int_to_int, _double),
+         Primitive("SHR", int_to_int, _half),
+         Primitive("doNEG", int_to_int, _negate),
+         Primitive("SQR", int_to_int, _square),
+         Primitive("MUL3", int_to_int, _triple),
+         Primitive("DIV3", int_to_int, _third),
+         Primitive("MUL4", int_to_int, _quad),
+         Primitive("DIV4", int_to_int, _quarter),
+     ] + [
+         Primitive("isPOS", int_to_bool, _pos),
+         Primitive("isNEG", int_to_bool, _neg),
+         Primitive("isEVEN", int_to_bool, _even),
+         Primitive("isODD", int_to_bool, _odd),
+     ] + [
+         Primitive("+", int_to_int_to_int, _add),
+         Primitive("-", int_to_int_to_int, _sub),
+         Primitive("*", int_to_int_to_int, _mult),
+         Primitive("MIN", int_to_int_to_int, _min),
+         Primitive("MAX", int_to_int_to_int, _max)
+     ]
+ 
+ def OldDeepcoderPrimitives():
+     return [
+         Primitive("head", arrow(tlist(tint), tint), _head),
+         Primitive("tail", arrow(tlist(tint), tint), _tail),
+         Primitive("take", arrow(tint, tlist(tint), tlist(tint)), _take),
+         Primitive("drop", arrow(tint, tlist(tint), tlist(tint)), _drop),
+         Primitive("access", arrow(tint, tlist(tint), tint), _access),
+         Primitive("minimum", arrow(tlist(tint), tint), _minimum),
+         Primitive("maximum", arrow(tlist(tint), tint), _maximum),
+         Primitive("reverse", arrow(tlist(tint), tlist(tint)), _reverse),
+         Primitive("sort", arrow(tlist(tint), tlist(tint)), _sort),
+         Primitive("sum", arrow(tlist(tint), tint), _sum)
+     ] + [
+         Primitive("map", arrow(int_to_int, tlist(tint), tlist(tint)), _map), #is this okay???
+         Primitive("filter_int", arrow(int_to_bool, tlist(tint), tlist(tint)), _filter), #is this okay???
+         Primitive("count", arrow(int_to_bool, tlist(tint), tint), _count), #is this okay???
+         Primitive("zipwith", arrow(int_to_int_to_int, tlist(tint), tlist(tint), tlist(tint)), _zipwith), #is this okay???
+         Primitive("scanl1", arrow(int_to_int_to_int, tlist(tint), tlist(tint)), _scanl1), #is this okay???
+     # ] + [
+     #     Primitive("succ", arrow(tint, tint), _succ),
+     #     Primitive("pred", arrow(tint, tint), _pred),
+     #     Primitive("double", arrow(tint, tint), _double),
+     #     Primitive("half", arrow(tint, tint), _half),
+     #     Primitive("neg", arrow(tint, tint), _neg),
+     #     Primitive("square", arrow(tint, tint), _square),
+     #     Primitive("triple", arrow(tint, tint), _triple),
+     #     Primitive("third", arrow(tint, tint), _third),
+     #     Primitive("quad", arrow(tint, tint), _quad),
+     #     Primitive("quarter", arrow(tint, tint), _quarter),
+     # ] + [
+     #     Primitive("pos", arrow(tint, tbool), _pos),
+     #     Primitive("neg", arrow(tint, tbool), _neg),
+     #     Primitive("even", arrow(tint, tbool), _even),
+     #     Primitive("odd", arrow(tint, tbool), _odd),
+     # ] + [
+     #     Primitive("add", arrow(tint, tint, tint), _add),
+     #     Primitive("sub", arrow(tint, tint, tint), _sub),
+     #     Primitive("mult", arrow(tint, tint, tint), _mult),
+     #     Primitive("min", arrow(tint, tint, tint), _min),
+     #     Primitive("max", arrow(tint, tint, tint), _max)
+     ] + [
+         Primitive("succ_fn", int_to_int, _succ),
+         Primitive("pred_fn", int_to_int, _pred),
+         Primitive("double_fn", int_to_int, _double),
+         Primitive("half_fn", int_to_int, _half),
+         Primitive("negate_fn", int_to_int, _negate),
+         Primitive("square_fn", int_to_int, _square),
+         Primitive("triple_fn", int_to_int, _triple),
+         Primitive("third_fn", int_to_int, _third),
+         Primitive("quad_fn", int_to_int, _quad),
+         Primitive("quarter_fn", int_to_int, _quarter),
+     ] + [
+         Primitive("pos_fn", int_to_bool, _pos),
+         Primitive("neg_fn", int_to_bool, _neg),
+         Primitive("even_fn", int_to_bool, _even),
+         Primitive("odd_fn", int_to_bool, _odd),
+     ] + [
+         Primitive("add_fn", int_to_int_to_int, _add),
+         Primitive("sub_fn", int_to_int_to_int, _sub),
+         Primitive("mult_fn", int_to_int_to_int, _mult),
+         Primitive("min_fn", int_to_int_to_int, _min),
+         Primitive("max_fn", int_to_int_to_int, _max)
+     ]
+ 
+ def deepcoderProductions():
+     return [(0.0, prim) for prim in deepcoderPrimitives()]
+ 
+ def flatten_program(p):
+     string = p.show(False)
+     num_inputs = string.count('lambda')
+     string = string.replace('lambda', '')
+     string = string.replace('(', '')
+     string = string.replace(')', '')
+     #remove '_fn' (optional)
+     for i in range(num_inputs):
+         string = string.replace('$' + str(num_inputs - i - 1), 'input_' + str(i))
+     string = string.split(' ')
+     string = list(filter(lambda x: x != '', string))
+     return string
+ 
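+ # Worked example (added comment): for a program whose p.show(False) is
+ # "(lambda (SUM (SORT $0)))", flatten_program returns ['SUM', 'SORT', 'input_0'].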
+ if __name__ == "__main__":
+     #g = Grammar.uniform(deepcoderPrimitives())
+     g = Grammar.fromProductions(deepcoderProductions(), logVariable=.9)
+     request = arrow(tlist(tint), tint, tint)
+     p = g.sample(request)
+     print("request:", request)
+     print("program:")
+     print(prettyProgram(p))
+     print("flattened_program:")
+     flat = flatten_program(p)
+     print(flat)
+ 
+     #robustfill output = names from productions + input_0-2 or 3
+ 
+ 
+     # # with open("/home/ellisk/om/ec/experimentOutputs/list_aic=1.0_arity=3_ET=1800_expandFrontier=2.0_it=4_likelihoodModel=all-or-nothing_MF=5_baseline=False_pc=10.0_L=1.0_K=5_rec=False.pickle", "rb") as handle:
+     # #     b = pickle.load(handle).grammars[-1]
+     # # print b
+ 
+     # p = Program.parse(
+     #     "(lambda (lambda (lambda (if (empty? $0) empty (cons (+ (car $1) (car $0)) ($2 (cdr $1) (cdr $0)))))))")
+     # t = arrow(tlist(tint), tlist(tint), tlist(tint)) # ,tlist(tbool))
+     # print(g.logLikelihood(arrow(t, t), p))
+     # assert False
+     # print(b.logLikelihood(arrow(t, t), p))
+ 
+     # # p = Program.parse("""(lambda (lambda
+     # #     (unfold 0
+     # #         (lambda (+ (index $0 $2) (index $0 $1)))
+     # #         (lambda (1+ $0))
+     # #         (lambda (eq? $0 (length $1))))))
+     # # """)
+     # p = Program.parse("""(lambda (lambda
+     #     (map (lambda (+ (index $0 $2) (index $0 $1))) (range (length $0)) )))""")
+     # # .replace("unfold", "#(lambda (lambda (lambda (lambda (fix1 $0 (lambda (lambda (#(lambda (lambda (lambda (if $0 empty (cons $1 $2))))) ($1 ($3 $0)) ($4 $0) ($5 $0)))))))))").\
+     # #     replace("length", "#(lambda (fix1 $0 (lambda (lambda (if (empty? $0) 0 (+ ($1 (cdr $0)) 1))))))").\
+     # #     replace("forloop", "(#(lambda (lambda (lambda (lambda (fix1 $0 (lambda (lambda (#(lambda (lambda (lambda (if $0 empty (cons $1 $2))))) ($1 ($3 $0)) ($4 $0) ($5 $0))))))))) (lambda (#(eq? 0) $0)) $0 (lambda (#(lambda (- $0 1)) $0)))").\
+     # #     replace("inc","#(lambda (+ $0 1))").\
+     # #     replace("drop","#(lambda (lambda (fix2 $0 $1 (lambda (lambda (lambda (if
+     # #         (#(eq? 0) $1) $0 (cdr ($2 (- $1 1) $0)))))))))"))
+     # print(p)
+     # print(g.logLikelihood(t, p))
+     # assert False
+ 
+     # print("??")
+     # p = Program.parse(
+     #     "#(lambda (#(lambda (lambda (lambda (fix1 $0 (lambda (lambda (if (empty? $0) $3 ($4 (car $0) ($1 (cdr $0)))))))))) (lambda $1) 1))")
+     # for j in range(10):
+     #     l = list(range(j))
+     #     print(l, p.evaluate([])(lambda x: x * 2)(l))
+     # print()
+     # print()
+ 
+     # print("multiply")
+     # p = Program.parse(
+     #     "(lambda (lambda (lambda (if (eq? $0 0) 0 (+ $1 ($2 $1 (- $0 1)))))))")
+     # print(g.logLikelihood(arrow(arrow(tint, tint, tint), tint, tint, tint), p))
+     # print()
+ 
+     # print("take until 0")
+     # p = Program.parse("(lambda (lambda (if (eq? $1 0) empty (cons $1 $0))))")
+     # print(g.logLikelihood(arrow(tint, tlist(tint), tlist(tint)), p))
+     # print()
+ 
+     # print("countdown primitive")
+     # p = Program.parse(
+     #     "(lambda (lambda (if (eq? $0 0) empty (cons (+ $0 1) ($1 (- $0 1))))))")
+     # print(
+     #     g.logLikelihood(
+     #         arrow(
+     #             arrow(tint, tlist(tint)),
+     #             arrow(tint, tlist(tint))), p))
+     # print(_fix(9)(p.evaluate([])))
+     # print("countdown w/ better primitives")
+     # p = Program.parse(
+     #     "(lambda (lambda (if (eq0 $0) empty (cons (+1 $0) ($1 (-1 $0))))))")
+     # print(
+     #     g.logLikelihood(
+     #         arrow(
+     #             arrow(tint, tlist(tint)),
+     #             arrow(tint, tlist(tint))), p))
+ 
+     # print()
+ 
+     # print("prepend zeros")
+     # p = Program.parse(
+     #     "(lambda (lambda (lambda (if (eq? $1 0) $0 (cons 0 ($2 (- $1 1) $0))))))")
+     # print(
+     #     g.logLikelihood(
+     #         arrow(
+     #             arrow(tint, tlist(tint), tlist(tint)),
+     #             tint, tlist(tint), tlist(tint)),
+     #         p))
+     # print()
+     # assert False
+ 
+     # p = Program.parse(
+     #     "(lambda (fix1 $0 (lambda (lambda (if (empty? $0) 0 (+ 1 ($1 (cdr $0))))))))")
+     # print(p.evaluate([])(list(range(17))))
+     # print(g.logLikelihood(arrow(tlist(tbool), tint), p))
+ 
+     # p = Program.parse(
+     #     "(lambda (lambda (if (empty? $0) 0 (+ 1 ($1 (cdr $0))))))")
+     # print(
+     #     g.logLikelihood(
+     #         arrow(
+     #             arrow(tlist(tbool), tint),
+     #             arrow(tlist(tbool), tint)), p))
+ 
+     # p = Program.parse(
+     #     "(lambda (fix1 $0 (lambda (lambda (if (empty? $0) 0 (+ (car $0) ($1 (cdr $0))))))))")
+ 
+     # print(p.evaluate([])(list(range(4))))
+     # print(g.logLikelihood(arrow(tlist(tint), tint), p))
+ 
+     # p = Program.parse(
+     #     "(lambda (lambda (if (empty? $0) 0 (+ (car $0) ($1 (cdr $0))))))")
+     # print(p)
+     # print(
+     #     g.logLikelihood(
+     #         arrow(
+     #             arrow(tlist(tint), tint),
+     #             tlist(tint), tint),
+     #         p))
+ 
+     # print("take")
+     # p = Program.parse(
+     #     "(lambda (lambda (lambda (if (eq? $1 0) empty (cons (car $0) ($2 (- $1 1) (cdr $0)))))))")
+     # print(p)
+     # print(
+     #     g.logLikelihood(
+     #         arrow(
+     #             arrow(tint, tlist(tint), tlist(tint)),
+     #             tint, tlist(tint), tlist(tint)),
+     #         p))
+     # assert False
+ 
+     # print(p.evaluate([])(list(range(4))))
+     # print(g.logLikelihood(arrow(tlist(tint), tlist(tint)), p))
+ 
+     # p = Program.parse(
+     #     """(lambda (fix (lambda (lambda (match $0 0 (lambda (lambda (+ $1 ($3 $0))))))) $0))""")
+     # print(p.evaluate([])(list(range(4))))
+     # print(g.logLikelihood(arrow(tlist(tint), tint), p))
dreamcoder/domains/misc/napsPrimitives.py ADDED
@@ -0,0 +1,198 @@
+ #napsPrimitives.py
+ from dreamcoder.program import Primitive, prettyProgram
+ from dreamcoder.grammar import Grammar
+ from dreamcoder.type import tlist, tint, arrow, baseType #, t0, t1, t2
+ 
+ #from functools import reduce
+ 
+ 
+ #types
+ PROGRAM = baseType("PROGRAM")
+ 
+ RECORD = baseType("RECORD")
+ FUNC = baseType("FUNC")
+ 
+ VAR = baseType("VAR")
+ STMT = baseType("STMT")
+ EXPR = baseType("EXPR")
+ ASSIGN = baseType("ASSIGN")
+ LHS = baseType("LHS")
+ IF = baseType("IF")
+ FOREACH = baseType("FOREACH")
+ WHILE = baseType("WHILE")
+ BREAK = baseType("BREAK")
+ CONTINUE = baseType("CONTINUE")
+ RETURN = baseType("RETURN")
+ NOOP = baseType("NOOP")
+ FIELD = baseType("FIELD")
+ CONSTANT = baseType("CONSTANT")
+ INVOKE = baseType("INVOKE")
+ TERNARY = baseType("TERNARY")
+ CAST = baseType("CAST")
+ TYPE = baseType("TYPE")
+ 
+ #other types
+ function_name = baseType("function_name")
+ field_name = baseType("field_name")
+ name = baseType("name") # for records and functions
+ value = baseType("value")
+ 
+ # definitions:
+ 
+ def _program(records): return lambda funcs: {'types': records, 'funcs': funcs}
+ # record
+ def _func(string): return lambda tp: lambda name: lambda vars1: lambda vars2: lambda stmts: [string, tp, name, vars1, vars2, stmts]
+ def _var(tp): return lambda name: ['var', tp, name]
+ # stmt
+ # expr
+ def _assign(tp): return lambda lhs: lambda expr: ['assign', tp, lhs, expr]
+ # lhs
+ def _if(tp): return lambda expr: lambda stmts1: lambda stmts2: ['if', tp, expr, stmts1, stmts2] # TODO
+ def _foreach(tp): return lambda var: lambda expr: lambda stmts: ['foreach', tp, var, expr, stmts] # TODO
+ def _while(tp): return lambda expr: lambda stmts1: lambda stmts2: ['while', tp, expr, stmts1, stmts2] # or: ['while', tp, expr, [stmts1], [stmts2]] #TODO
+ # break
+ # continue
+ def _return(tp): return lambda expr: ['return', tp, expr]
+ # noop
+ def _field(tp): return lambda expr: lambda field_name: ['field', tp, expr, field_name]
+ def _constant(tp): return lambda value: ['val', tp, value] #TODO deal with value
+ def _invoke(tp): return lambda function_name: lambda exprs: ['invoke', tp, function_name, exprs] # TODO, deal with fn_name and lists
+ def _ternary(tp): return lambda expr1: lambda expr2: lambda expr3: ['?:', tp, expr1, expr2, expr3]
+ def _cast(tp): return lambda expr: ['cast', tp, expr]
+ 
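+ # Shape sketch (added comment; 'var0' and the literals are hypothetical):
+ # each constructor returns a NAPS-style UAST node as a plain nested list, e.g.
+ #   _assign('int')(['var', 'int', 'var0'])(['val', 'int', 0])
+ #       == ['assign', 'int', ['var', 'int', 'var0'], ['val', 'int', 0]]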
+ # types:
+ 
+ # TODO: deal with lists - x
+ # TODO: deal with names
+ # TODO: deal with values - x
+ 
+ # TODO: deal with the program/record __main__ and __globals__ stuff
+ 
+ 
+ def napsPrimitives():
+     return [
+         Primitive("program", arrow(tlist(RECORD), tlist(FUNC), PROGRAM), _program), # TODO
+         # RECORD
+         Primitive("func", arrow(TYPE, name, tlist(VAR), tlist(VAR), tlist(STMT), FUNC), _func('func')), # TODO
+         Primitive("ctor", arrow(TYPE, name, tlist(VAR), tlist(VAR), tlist(STMT), FUNC), _func('ctor')),
+         Primitive("var", arrow(TYPE, name, VAR), _var)
+     ] + [
+         # STMT ::= EXPR | IF | FOREACH | WHILE | BREAK | CONTINUE | RETURN | NOOP
+         Primitive("stmt_expr", arrow(EXPR, STMT), lambda x: x),
+         Primitive("stmt_if", arrow(IF, STMT), lambda x: x),
+         Primitive("stmt_foreach", arrow(FOREACH, STMT), lambda x: x),
+         Primitive("stmt_while", arrow(WHILE, STMT), lambda x: x),
+         Primitive("stmt_break", arrow(BREAK, STMT), lambda x: x),
+         Primitive("stmt_continue", arrow(CONTINUE, STMT), lambda x: x),
+         Primitive("stmt_return", arrow(RETURN, STMT), lambda x: x),
+         Primitive("stmt_noop", arrow(NOOP, STMT), lambda x: x)
+     ] + [
+         # EXPR ::= ASSIGN | VAR | FIELD | CONSTANT | INVOKE | TERNARY | CAST
+         Primitive("expr_assign", arrow(ASSIGN, EXPR), lambda x: x),
+         Primitive("expr_var", arrow(VAR, EXPR), lambda x: x),
+         Primitive("expr_field", arrow(FIELD, EXPR), lambda x: x),
+         Primitive("expr_constant", arrow(CONSTANT, EXPR), lambda x: x),
+         Primitive("expr_invoke", arrow(INVOKE, EXPR), lambda x: x),
+         Primitive("expr_ternary", arrow(TERNARY, EXPR), lambda x: x),
+         Primitive("expr_cast", arrow(CAST, EXPR), lambda x: x)
+     ] + [
+         Primitive("assign", arrow(TYPE, LHS, EXPR, ASSIGN), _assign)
+     ] + [
+         # LHS ::= VAR | FIELD | INVOKE
+         Primitive("lhs_var", arrow(VAR, LHS), lambda x: x),
+         Primitive("lhs_field", arrow(FIELD, LHS), lambda x: x),
+         Primitive("lhs_invoke", arrow(INVOKE, LHS), lambda x: x)
+     ] + [
+         Primitive("if", arrow(TYPE, EXPR, tlist(STMT), tlist(STMT), IF), _if),
+         Primitive("foreach", arrow(TYPE, VAR, EXPR, tlist(STMT), FOREACH), _foreach),
+         Primitive("while", arrow(TYPE, EXPR, tlist(STMT), tlist(STMT), WHILE), _while),
+         Primitive("break", arrow(TYPE, BREAK), lambda tp: ['break', tp]),
+         Primitive("continue", arrow(TYPE, CONTINUE), lambda tp: ['continue', tp]),
+         Primitive("return", arrow(TYPE, EXPR, RETURN), _return),
+         Primitive("noop", NOOP, ['noop']),
+         Primitive("field", arrow(TYPE, EXPR, field_name, FIELD), _field), # TODO
+         Primitive("constant", arrow(TYPE, value, CONSTANT), _constant),
+         Primitive("invoke", arrow(TYPE, function_name, tlist(EXPR), INVOKE), _invoke), # TODO
+         Primitive("ternary", arrow(TYPE, EXPR, EXPR, EXPR, TERNARY), _ternary),
+         Primitive("cast", arrow(TYPE, EXPR, CAST), _cast)
+     ] + [
+         # below are TYPE:
+         Primitive("bool", TYPE, 'bool'),
+         Primitive("char", TYPE, 'char'),
+         Primitive("char*", TYPE, 'char*'),
+         Primitive("int", TYPE, 'int'),
+         Primitive("real", TYPE, 'real'),
+         Primitive("array", arrow(TYPE, TYPE), lambda tp: tp + '*'),
+         Primitive("set", arrow(TYPE, TYPE), lambda tp: tp + '%'),
+         Primitive("map", arrow(TYPE, TYPE, TYPE), lambda tp1: lambda tp2: '<' + tp1 + '|' + tp2 + '>'),
+         Primitive("record_name", TYPE, 'record_name#') # TODO
+     ] + [
+         #stuff about lists:
+         # STMTs, EXPRs, VARs, maybe Funcs and records
+         Primitive('list_init_stmt', arrow(STMT, tlist(STMT)), lambda stmt: [stmt]),
+         Primitive('list_add_stmt', arrow(STMT, tlist(STMT), tlist(STMT)), lambda stmt: lambda stmts: stmts + [stmt]),
+         Primitive('list_init_expr', arrow(EXPR, tlist(EXPR)), lambda expr: [expr]),
+         Primitive('list_add_expr', arrow(EXPR, tlist(EXPR), tlist(EXPR)), lambda expr: lambda exprs: exprs + [expr]),
+         Primitive('list_init_var', arrow(VAR, tlist(VAR)), lambda var: [var]),
+         Primitive('list_add_var', arrow(VAR, tlist(VAR), tlist(VAR)), lambda var: lambda _vars: _vars + [var])
+     ] + [
+         # value
+         Primitive('0', value, 0),
+         Primitive("1", value, "1"),
+         Primitive("-1", value, "-1")
+         # ...
+     ] + [
+         # function_name:
+         Primitive('+', function_name, '+'),
+         Primitive('&&', function_name, "&&"),
+         Primitive("!", function_name, "!"),
+         Primitive("!=", function_name, "!="),
+         Primitive("string_find", function_name, "string_find")
+         # ...
+     ] + [
+         # field_name:
+         Primitive('', field_name, '')
+         # ...
+     ] + [
+         #
+         Primitive(f'var{str(i)}', name, f'var{str(i)}') for i in range(12)
+     ]
+ 
+ 
+ #for first pass, can just hard code vars and maps n stuff
+ 
+ def ec_prog_to_uast(prog): # TODO
+     # ideally, just evaluate and then parse
+     uast = prog.evaluate([])
+     return uast
+ 
+ def napsProductions():
+     return [(0.0, prim) for prim in napsPrimitives()]
+ 
+ # def flatten_program(p):
+ #     string = p.show(False)
+ #     num_inputs = string.count('lambda')
+ #     string = string.replace('lambda', '')
+ #     string = string.replace('(', '')
+ #     string = string.replace(')', '')
+ #     #remove '_fn' (optional)
+ #     for i in range(num_inputs):
+ #         string = string.replace('$' + str(num_inputs-i-1), 'input_' + str(i))
+ #     string = string.split(' ')
+ #     string = list(filter(lambda x: x != '', string))
+ #     return string
+ 
+ if __name__ == "__main__":
+     #g = Grammar.uniform(napsPrimitives())
+     g = Grammar.fromProductions(napsProductions(), logVariable=.9)
+     request = arrow(tlist(tint), tint, tint)
+     p = g.sample(request)
+     print("request:", request)
+     print("program:")
+     print(prettyProgram(p))
+     #print("flattened_program:")
+     #flat = flatten_program(p)
+     #print(flat)
dreamcoder/domains/regex/__init__.py ADDED
File without changes
dreamcoder/domains/regex/groundtruthRegexes.py ADDED
@@ -0,0 +1,172 @@
+ 
+ #dict of gt regexes
+ 
+ """
+ pre.create(".+"),
+ pre.create("\d+"),
+ pre.create("\w+"),
+ pre.create("\s+"),
+ pre.create("\\u+"),
+ pre.create("\l+")
+ """
+ gt_dict = {
+     776: "JPC\\u\\u\\d+\\.png", 922: "WHS\\d_\\d+", 354: "\\u+",
+     523: "(\\u)+|\\.", 184: "\\.\\d+", 501: "u\\d\\d",
+     760: "\\u\\u", 49: "(\\u)+\\u\\d?", 732: "\\uR5\\d\\d",
+     450: "-\\d(\\.(\\d)+)?", 350: "\\u\\u", 467: "hu\\d(\\d|\\u)+",
+     622: "A(\\d|\\u)**", 476: "\\u+", 554: "\\u\\u",
+     940: "\\u\\u?", 496: "\\u\\u", 369: "\\u\\u\\u",
+     596: "\\u+", 720: "\\(\\d\\d\\d\\) \\d\\d\\d-\\d\\d\\d\\d", 53: "rec-\\d\\d\\d?-(org)|(dup-0)",
+     150: "N\\d\\d", 741: "#\\d\\d\\d", 18: "A|C-\\d+-\\d+",
+     589: "A(\\u|\\d)++", 666: "\\(\\d\\d\\d\\) \\d\\d\\d-\\d\\d\\d\\d", 581: "us13\\u\\d\\d",
+     299: "E07000\\d\\d\\d", 638: "\\l+\\d+\\l+\\d+", 364: "\\u\\u",
+     334: "-00:\\d\\d:\\d\\d.\\d", 38: "SRX89\\d+", 247: "'\\d\\d:\\d\\d:00'",
+     506: "(S|H)\\d+", 891: "(r|v)\\d?", 911: "KW-\\d+",
+     792: "\\d*\\u*", 508: "N000\\d+", 842: "-?\\d?\\d\\.\\d\\d%",
+     200: "\\u\\u", 694: "\\(\\d+\\)", 210: "(\\d(\\.\\d)?)|(--)",
+     298: "DS_25(\\u|\\d)+", 668: "\\u+", 939: "ms0\\d+",
+     944: "\\u+\\d?", 731: "ManH.0\\d\\d", 229: "\\u+(-\\u+)?",
+     28: "Y201\\d/\\d\\d\\d\\d", 374: "q000\\d(_000\\d)?", 819: "\\d*\\l*\\d*",
+     516: "-122.3\\d+", 417: "\\u\\uT\\uB", 660: "ENGL?\\d\\d\\d",
+     585: "M?\\u+", 325: "BUS M \\d\\d\\d.*", 823: "\\u\\u\\u",
+     515: "L|\\u - (\\?\\?)|(\\d?\\d\\.\\d lbs\\.)", 864: "\\u+", 359: "MAM\\.OSBS\\.201\\d\\.\\d\\d",
+     594: "(\\u|\\d)+( (\\u|\\d)+)*", 788: "-\\d(,\\d+)?", 188: "cat\\. \\d\\d",
+     355: ".+", 799: "\\u\\d\\d", 902: "\\u\\d\\d",
+     920: "A\\.\\d\\d", 330: "Resp\\d\\d", 396: "\\u+(( |/)\\u+)?",
+     393: "US $ \\d\\.\\d\\d", 680: "Z:-?0\\.\\d\\d", 744: "t1_cv(\\l|\\d)+",
+     461: "(\\u|\\l)+\\d+", 631: "$\\d+\\.\\d+", 195: "(OLE)?\\d+",
+     693: "\\u", 577: "EFO_000\\d+", 392: "$\\d+(,\\d\\d\\d)*\\.00",
+     688: "\\u+( \\u+)*", 816: "\\u\\u\\u", 489: "UK\\u\\d",
+     251: "\\l\\l\\l", 653: "C\\d+", 769: "(\\u|\\l|\\d|-)+\\d+",
+     991: "Q\\d-201\\d", 342: "\\u\\u\\d\\d\\d\\d", 308: "\\u\\u\\u\\u",
+     136: "IMPC_\\u\\u\\u_\\d\\d\\d_\\d\\d\\d", 327: "#\\d+((/|-)\\d+)*", 981: "\\u\\u\\u",
+     892: "(.|\\l)*", 375: "P\\u\\.\\d\\d\\d\\d\\.\\d\\d\\d", 499: "A000\\d+",
+     474: "\\u+", 50: "V06\\d+", 381: "F?\\d+",
+     883: "-79.\\d+", 173: "(\\u|\\l)+\\d+", 147: "\\u\\u\\u-\\u\\u\\u",
+     419: "\\u\\u", 961: "-?\\d\\.\\d*", 148: "Q\\d\\d",
+     975: "(\\d|\\u)+", 79: "\\d+(,\\d\\d\\d)+", 775: "\\u\\l\\l \\d+ \\d\\d\\d\\d",
+     774: "FOS\\d\\d+", 561: ".+", 509: "S000\\d+",
+     494: "S1900\\d+", 119: "$\\d\\d(,\\d\\d\\d)+", 29: "(\\u|\\l|\\d)+",
+     121: "(\\d|\\u|\\.|/|\\(|\\))+", 61: "R \\d\\d\\d.\\d\\d", 871: "-0.7\\d+",
+     639: "\\u+?\\d+", 729: "COMISARIA \\d\\d", 193: "\\u\\d\\d",
+     752: "(.*|\\u\\.?)+", 17: "$\\d.\\d\\d", 914: "R\\d\\d\\d\\d",
+     510: "P\\d000\\d\\d\\d\\d", 443: "(W|L) \\d-\\d+", 20: "MDEL\\d\\d?\\.\\d\\l",
+     64: "c04p0100(\\l|\\d)", 301: "(\\u|\\d)+(-(\\u|\\d)+)*", 664: "N\\d",
+     493: "[0\\.0\\d+]", 765: "-?\\d\\.\\d+( \\(0\\.\\d+\\))?"
+ }
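+ # Reading note (added comment; assumes pregex notation, where "\\u" matches an
+ # upper-case letter and "\\l" a lower-case one): e.g. gt_dict[776],
+ # "JPC\\u\\u\\d+\\.png", matches strings such as "JPCAB12.png".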
+ badRegexTasks = {
+     "Data column no. 922", "Data column no. 184", "Data column no. 467",
+     "Data column no. 476", "Data column no. 150", "Data column no. 299",
+     "Data column no. 334", "Data column no. 493", "Data column no. 891",
+     "Data column no. 792", "Data column no. 765", "Data column no. 944",
+     "Data column no. 374", "Data column no. 660", "Data column no. 188",
+     "Data column no. 920", "Data column no. 330", "Data column no. 396",
+     "Data column no. 680", "Data column no. 769", "Data column no. 308",
+     "Data column no. 375", "Data column no. 474", "Data column no. 79",
+     "Data column no. 871", "Data column no. 729", "Data column no. 664",
+ }
dreamcoder/domains/regex/main.py ADDED
@@ -0,0 +1,384 @@
+ # analog of list.py for regex tasks. Responsible for actually running the task.
+ 
+ from dreamcoder.domains.regex.makeRegexTasks import makeOldTasks, makeLongTasks, makeShortTasks, makeWordTasks, makeNumberTasks, makeHandPickedTasks, makeNewTasks, makeNewNumberTasks
+ from dreamcoder.domains.regex.regexPrimitives import basePrimitives, altPrimitives, easyWordsPrimitives, alt2Primitives, concatPrimitives, reducedConcatPrimitives, strConstConcatPrimitives, PRC
+ from dreamcoder.dreamcoder import explorationCompression, Task
+ from dreamcoder.grammar import Grammar
+ from dreamcoder.likelihoodModel import add_cutoff_values, add_string_constants
+ from dreamcoder.program import Abstraction, Application
+ from dreamcoder.type import tpregex
+ from dreamcoder.utilities import eprint, flatten, testTrainSplit, POSITIVEINFINITY
+ 
+ import random
+ import math
+ import pregex as pre
+ import os
+ 
+ try:
+     from dreamcoder.recognition import RecurrentFeatureExtractor, JSONFeatureExtractor
+ 
+     class LearnedFeatureExtractor(RecurrentFeatureExtractor):
+         H = 64
+         special = 'regex'
+ 
+         def tokenize(self, examples):
+             def sanitize(l): return [z if z in self.lexicon else "?"
+                                      for z_ in l
+                                      for z in (z_ if isinstance(z_, list) else [z_])]
+ 
+             tokenized = []
+             for xs, y in examples:
+                 if isinstance(y, list):
+                     y = ["LIST_START"] + y + ["LIST_END"]
+                 else:
+                     y = [y]
+                 y = sanitize(y)
+                 if len(y) > self.maximumLength:
+                     return None
+ 
+                 serializedInputs = []
+                 for xi, x in enumerate(xs):
+                     if isinstance(x, list):
+                         x = ["LIST_START"] + x + ["LIST_END"]
+                     else:
+                         x = [x]
+                     x = sanitize(x)
+                     if len(x) > self.maximumLength:
+                         return None
+                     serializedInputs.append(x)
+ 
+                 tokenized.append((tuple(serializedInputs), y))
+ 
+             return tokenized
+         def __init__(self, tasks, testingTasks=[], cuda=False):
+             self.lexicon = set(flatten((t.examples for t in tasks + testingTasks),
+                                        abort=lambda x: isinstance(x, str))).union({"LIST_START", "LIST_END", "?"})
+ 
+             self.num_examples_list = [len(t.examples) for t in tasks]
+ 
+             # Calculate the maximum length
+             self.maximumLength = POSITIVEINFINITY
+             self.maximumLength = max(len(l)
+                                      for t in tasks + testingTasks
+                                      for xs, y in self.tokenize(t.examples)
+                                      for l in [y] + [x for x in xs])
+ 
+             super(LearnedFeatureExtractor, self).__init__(
+                 lexicon=list(self.lexicon),
+                 tasks=tasks,
+                 cuda=cuda,
+                 H=self.H,
+                 bidirectional=True)
+             self.parallelTaskOfProgram = False
+ 
+         def taskOfProgram(self, p, t):
+             #raise NotImplementedError
+             num_examples = random.choice(self.num_examples_list)
+ 
+             p = p.visit(ConstantInstantiateVisitor.SINGLE)
+ 
+             preg = p.evaluate([])(pre.String(""))
+             t = Task("Helm", t, [((), list(preg.sample())) for _ in range(num_examples)])
+             return t
+ except: pass
+ #in init: loop over tasks, save lengths,
+ 
+ 
+ class ConstantInstantiateVisitor(object):
+     def __init__(self):
+         self.regexes = [
+             pre.create(".+"),
+             pre.create("\d+"),
+             pre.create("\w+"),
+             pre.create("\s+"),
+             pre.create("\\u+"),
+             pre.create("\l+")]
+ 
+     def primitive(self, e):
+         if e.name == "r_const":
+             #return Primitive("STRING", e.tp, random.choice(self.words))
+             s = random.choice(self.regexes).sample() #random string const
+             s = pre.String(s)
+             e.value = PRC(s, arity=0)
+         return e
+ 
+     def invented(self, e): return e.body.visit(self)
+ 
+     def index(self, e): return e
+ 
+     def application(self, e):
+         return Application(e.f.visit(self), e.x.visit(self))
+ 
+     def abstraction(self, e):
+         return Abstraction(e.body.visit(self))
+     #TODO fix
+ 
+ 
+ class MyJSONFeatureExtractor(JSONFeatureExtractor):
+     N_EXAMPLES = 5
+ 
+     def _featuresOfProgram(self, program, tp):
+         try:
+             preg = program.evaluate([])
+             # if 'left_paren' in program.show(False):
+             #eprint("string_pregex:", string_pregex)
+             #eprint("string_pregex:", string_pregex)
+ 
+         except IndexError:
+             # free variable
+             return None
+         except Exception as e:
+             eprint("Exception during evaluation:", e)
+             if "Attempt to evaluate fragment variable" in str(e):
+                 eprint("program (bc fragment error)", program)
+             return None
+ 
+         examples = []
+ 
+         for _ in range(self.N_EXAMPLES * 5): # oh this is arbitrary ig
+ 
+             try:
+                 y = preg.sample() # TODO
+ 
+                 #this line should keep inputs short, so that helmholtzbatch can be large
+                 #allows it to try other samples
+                 #(Could also return None off the bat... idk which is better)
+                 #if len(y) > 20:
+                 #    continue
+                 #eprint(tp, program, x, y)
+                 examples.append(y)
+             except BaseException:
+                 continue
+             if len(examples) >= self.N_EXAMPLES:
+                 break
+         else:
+             return None
+         return examples # changed to list_features(examples) from examples
+ 
+ 
+ def regex_options(parser):
+     parser.add_argument("--maxTasks", type=int,
+                         default=500,
+                         help="truncate tasks to fit within this boundary")
+     parser.add_argument(
+         "--maxExamples",
+         type=int,
+         default=10,
+         help="truncate number of examples per task to fit within this boundary")
+     parser.add_argument("--tasks",
+                         default="long",
+                         help="which tasks to use",
+                         choices=["old", "short", "long", "words", "number", "handpicked", "new", "newNumber"])
+     parser.add_argument("--primitives",
+                         default="concat",
+                         help="Which primitive set to use",
+                         choices=["base", "alt1", "easyWords", "alt2", "concat", "reduced", "strConst"])
+     parser.add_argument("--extractor", type=str,
+                         choices=["hand", "deep", "learned", "json"],
+                         default="learned") # if i switch to json it breaks
+     parser.add_argument("--split", metavar="TRAIN_RATIO",
+                         type=float,
+                         default=0.8,
+                         help="split test/train")
+     parser.add_argument("-H", "--hidden", type=int,
+                         default=256,
+                         help="number of hidden units")
+     parser.add_argument("--likelihoodModel",
+                         default="probabilistic",
+                         help="likelihood Model",
+                         choices=["probabilistic", "all-or-nothing"])
+     parser.add_argument("--topk_use_map",
+                         dest="topk_use_only_likelihood",
+                         action="store_false")
+     parser.add_argument("--debug",
+                         dest="debug",
+                         action="store_true")
+     parser.add_argument("--ll_cutoff",
+                         dest="use_ll_cutoff",
+                         nargs='*',
+                         default=False,
+                         help="use ll cutoff for training tasks (for probabilistic likelihood model only). default is False,")
+     parser.add_argument("--use_str_const",
+                         action="store_true",
+                         help="use string constants")
+ 
+     """parser.add_argument("--stardecay",
+                         type=float,
+                         dest="stardecay",
+                         default=0.5,
+                         help="p value for kleenestar and plus")"""
+ 
+ # Lucas recommends putting a struct with the definitions of the primitives here.
+ # TODO:
+ # Build likelihood function
+ # modify NN
+ # make primitives
+ # make tasks
+ 
+ 
+ def main(args):
+     """
+     Takes the return value of the `commandlineArguments()` function as input and
+     trains/tests the model on regular expressions.
+     """
+     #for dreaming
+ 
+     #parse use_ll_cutoff
+     use_ll_cutoff = args.pop('use_ll_cutoff')
+     if use_ll_cutoff is not False:
+ 
+         #if use_ll_cutoff is a list of strings, then train_ll_cutoff and test_ll_cutoff
+         #will be tuples of that string followed by the actual model
+ 
+         if len(use_ll_cutoff) == 1:
+             train_ll_cutoff = use_ll_cutoff[0] # make_cutoff_model(use_ll_cutoff[0], tasks))
+             test_ll_cutoff = use_ll_cutoff[0] # make_cutoff_model(use_ll_cutoff[0], tasks))
+         else:
+             assert len(use_ll_cutoff) == 2
+             train_ll_cutoff = use_ll_cutoff[0] #make_cutoff_model(use_ll_cutoff[0], tasks))
+             test_ll_cutoff = use_ll_cutoff[1] #make_cutoff_model(use_ll_cutoff[1], tasks))
+     else:
+         train_ll_cutoff = None
+         test_ll_cutoff = None
+ 
+     regexTasks = {"old": makeOldTasks,
+                   "short": makeShortTasks,
+                   "long": makeLongTasks,
+                   "words": makeWordTasks,
+                   "number": makeNumberTasks,
+                   "handpicked": makeHandPickedTasks,
+                   "new": makeNewTasks,
+                   "newNumber": makeNewNumberTasks
+                   }[args.pop("tasks")]
+ 
+     tasks = regexTasks() # TODO
+     eprint("Generated", len(tasks), "tasks")
+ 
+     maxTasks = args.pop("maxTasks")
+     if len(tasks) > maxTasks:
+         eprint("Unwilling to handle {} tasks, truncating...".format(len(tasks)))
+         seed = 42 # previously this was hardcoded and never changed
+         random.seed(seed)
+         random.shuffle(tasks)
+         del tasks[maxTasks:]
+ 
+     maxExamples = args.pop("maxExamples")
+ 
+     split = args.pop("split")
+     test, train = testTrainSplit(tasks, split)
+     eprint("Split tasks into %d/%d test/train" % (len(test), len(train)))
+ 
+     test = add_cutoff_values(test, test_ll_cutoff)
+     train = add_cutoff_values(train, train_ll_cutoff)
+     eprint("added cutoff values to tasks, train:", train_ll_cutoff, ", test:", test_ll_cutoff)
+ 
+     if args.pop("use_str_const"):
+         assert args["primitives"] == "strConst" or args["primitives"] == "reduced"
+         ConstantInstantiateVisitor.SINGLE = \
+             ConstantInstantiateVisitor()
+         test = add_string_constants(test)
+         train = add_string_constants(train)
+         eprint("added string constants to test and train")
+ 
+     for task in test + train:
+         if len(task.examples) > maxExamples:
+             task.examples = task.examples[:maxExamples]
+ 
+         task.specialTask = ("regex", {"cutoff": task.ll_cutoff, "str_const": task.str_const})
+         task.examples = [(xs, [y for y in ys])
+                          for xs, ys in task.examples]
+         task.maxParameters = 1
+ 
+     # from list stuff
+     primtype = args.pop("primitives")
+     prims = {"base": basePrimitives,
+              "alt1": altPrimitives,
+              "alt2": alt2Primitives,
+              "easyWords": easyWordsPrimitives,
+              "concat": concatPrimitives,
+              "reduced": reducedConcatPrimitives,
+              "strConst": strConstConcatPrimitives
+              }[primtype]
+ 
+     extractor = {
+         "learned": LearnedFeatureExtractor,
+         "json": MyJSONFeatureExtractor
+     }[args.pop("extractor")]
+ 
+     extractor.H = args.pop("hidden")
+ 
+     #stardecay = args.stardecay
+     #stardecay = args.pop('stardecay')
+     #decaystr = 'd' + str(stardecay)
+     import datetime
+ 
+     timestamp = datetime.datetime.now().isoformat()
+     outputDirectory = "experimentOutputs/regex/%s" % timestamp
+     os.system("mkdir -p %s" % outputDirectory)
+ 
+     args.update({
+         "featureExtractor": extractor,
+         "outputPrefix": "%s/regex" % (outputDirectory),
+         "evaluationTimeout": 0.005,
+         "topk_use_only_likelihood": True,
+         "maximumFrontier": 10,
+         "compressor": args.get("compressor", "ocaml")
+     })
+     ####
+ 
+     # use the
+     #prim_list = prims(stardecay)
+     prim_list = prims()
+     specials = ["r_kleene", "r_plus", "r_maybe", "r_alt", "r_concat"]
+     n_base_prim = len(prim_list) - len(specials)
+ 
+     productions = [
+         (math.log(0.5 / float(n_base_prim)),
+          prim) if prim.name not in specials else (
+             math.log(0.10),
+             prim) for prim in prim_list]
+ 
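+     # Weighting note (added comment): each of the five specials gets probability
+     # exp(log(0.10)) = 0.10, i.e. 0.5 in total, and the remaining 0.5 is split
+     # uniformly over the n_base_prim base primitives (with 20 base primitives,
+     # that is 0.5 / 20 = 0.025 each).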
+     baseGrammar = Grammar.fromProductions(productions, continuationType=tpregex)
+     #baseGrammar = Grammar.uniform(prims())
+ 
+     #for i in range(100):
+     #    eprint(baseGrammar.sample(tpregex))
+ 
+     #eprint(baseGrammar)
+     #explore
+     test_stuff = args.pop("debug")
+     if test_stuff:
+         eprint(baseGrammar)
+         eprint("sampled programs from prior:")
+         for i in range(100): #100
+             eprint(baseGrammar.sample(test[0].request, maximumDepth=1000))
+         eprint("""half the probability mass is on higher-order primitives.
+ Therefore half of enumerated programs should have more than one node.
+ However, we do not observe this.
+ Instead we see a very small fraction of programs have more than one node.
+ So something seems to be wrong with grammar.sample.
+ 
+ Furthermore: observe the large print statement above.
+ This prints the candidates for sampleDistribution in grammar.sample.
+ the first element of each tuple is the probability passed into sampleDistribution.
+ Half of the probability mass should be on the functions, but instead they are equally
+ weighted with the constants. If you look at the grammar above, this is an error!!!!
+ """)
+         assert False
+ 
+     del args["likelihoodModel"]
+     explorationCompression(baseGrammar, train,
+                            testingTasks=test,
+                            **args)
dreamcoder/domains/regex/makeRegexTasks.py ADDED
@@ -0,0 +1,347 @@
+ import dill
+ import os
+ import json
+ from string import printable
+ 
+ import sys
+ try:
+     from pregex import pregex
+ except:
+     print("Failure to load pregex. This is only acceptable if using pypy", file=sys.stderr)
+ 
+ from dreamcoder.task import Task
+ from dreamcoder.type import tpregex, arrow
+ from dreamcoder.utilities import get_data_dir
+ 
+ 
+ def makeOldTasks():
+     # a series of tasks
+ 
+     taskfile = os.path.join(get_data_dir(), 'data_filtered.json')
+     #task_list = pickle.load(open(taskfile, 'rb'))
+ 
+     with open(taskfile) as f:
+         file_contents = f.read()
+     task_list = json.loads(file_contents)
+ 
+     # if I were to just dump all of them:
+     regextasks = [
+         Task("Luke data column no." + str(i),
+              arrow(tpregex, tpregex),
+              [((), example) for example in task_list[i]]
+              ) for i in range(len(task_list))]
+ 
+     """ regextasks = [
+         Task("length bool", arrow(none,tstr),
+              [((l,), len(l))
+               for _ in range(10)
+               for l in [[flip() for _ in range(randint(0,10)) ]] ]),
+         Task("length int", arrow(none,tstr),
+              [((l,), len(l))
+               for _ in range(10)
+               for l in [randomList()] ]),
+     ]
+     """
+     return regextasks # some list of tasks
+ 
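+ # Shape note (added comment): every regex Task built in this file pairs an
+ # empty input tuple with one observed string, i.e. examples look like
+ # [((), example), ...], because regex induction conditions only on outputs.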
+ 
+ 
+ def makeShortTasks():
+ 
+     #load new data:
+ 
+     taskfile = os.path.join(get_data_dir(), "regex_data_csv_900.p")
+ 
+     with open(taskfile, 'rb') as handle:
+         data = dill.load(handle)
+ 
+     tasklist = data[0][:100] #a list of tasks (data columns)
+ 
+     regextasks = [
+         Task("Data column no. " + str(i),
+              arrow(tpregex, tpregex),
+              [((), example) for example in task]
+              ) for i, task in enumerate(tasklist)]
+ 
+     return regextasks
+ 
+ 
+ def makeLongTasks():
+ 
+     #load new data:
+ 
+     taskfile = os.path.join(get_data_dir(), "regex_data_csv_900.p")
+ 
+     with open(taskfile, 'rb') as handle:
+         data = dill.load(handle)
+ 
+     tasklist = data[0] #a list of tasks
+ 
+     regextasks = [
+         Task("Data column no. " + str(i),
+              arrow(tpregex, tpregex),
+              [((), example) for example in task]
+              ) for i, task in enumerate(tasklist)]
+ 
+     return regextasks
+ 
+ 
+ def makeWordTasks():
+ 
+     #load new data:
+ 
+     taskfile = os.path.join(get_data_dir(), "regex_data_csv_900.p")
+ 
+     with open(taskfile, 'rb') as handle:
+         data = dill.load(handle)
+ 
+     tasklist = data[0] #a list of tasks
+ 
+     all_upper = [0, 2, 8, 9, 10, 11, 12, 17, 18, 19, 20, 22]
+     all_lower = [1]
+ 
+     # match_col(data[0],'\\u(\l+)')
+     one_capital_lower_plus = [144, 200, 241, 242, 247, 296, 390, 392, 444, 445, 481, 483, 485, 489, 493, 542, 549, 550, 581]
+ 
+     #match_col(data[0],'(\l ?)+')
+     lower_with_maybe_spaces = [1, 42, 47, 99, 100, 102, 201, 246, 248, 293, 294, 345, 437, 545, 590]
+ 
+     #match_col(data[0],'(\\u\l+ ?)+')
+     capital_then_lower_maybe_spaces = [144, 200, 241, 242, 247, 296, 390, 392, 395, 438, 444, 445, 481, 483, 484, 485, 487, 489, 493, 494, 542, 546, 549, 550, 578, 581, 582, 588, 591, 624, 629]
+ 
+     #match_col(data[0],'(\\u+ ?)+')
+     all_caps_spaces = [0, 2, 8, 9, 10, 11, 12, 17, 18, 19, 20, 22, 25, 26, 35, 36, 43, 45, 46, 49, 50, 52, 56, 59, 87, 89, 95, 101, 140, 147, 148, 149, 199, 332, 336, 397, 491, 492, 495, 580, 610]
+ 
+     #one_capital_and_lower = [566, 550, 549, 542, 505, 493, 494, 489, 488, 485, 483, 481, 445, 444, 438, 296, 241, 242, 200, ]
+     #all_lower_with_a_space = [545]
+     #all_lower_maybe_space = [534]
+     #one_capital_lower_maybe_spaces = [259, 262, 263, 264]
+ 
+     #full_list = test_list + train_list
+     train_list = []
+     full_list = all_upper + all_lower + one_capital_lower_plus + lower_with_maybe_spaces + capital_then_lower_maybe_spaces + all_caps_spaces
+ 
+     regextasks = [
+         Task("Data column no. " + str(i),
+              arrow(tpregex, tpregex),
+              [((), example) for example in task]
+              ) for i, task in enumerate(tasklist) if i in full_list]
+ 
+     for i in train_list:
+         regextasks[i].mustTrain = True
+ 
+     return regextasks
+ 
+ 
+ def makeNumberTasks():
+ 
+     #load new data:
+ 
+     taskfile = os.path.join(get_data_dir(), "regex_data_csv_900.p")
+ 
+     with open(taskfile, 'rb') as handle:
+         data = dill.load(handle)
+ 
+     tasklist = data[0] #a list of tasks
+ 
+     #match_col(data[0],'\d*\.\d*')
+     raw_decimals = [121, 122, 163, 164, 165, 170, 172, 173, 175, 178, 218, 228, 230, 231, 252, 253,
+                     254, 258, 259, 305, 320, 330, 334, 340, 348, 350, 351, 352, 353, 355, 357, 358, 361, 363, 364,
+                     371, 380, 382, 409, 410, 411, 447, 448, 449, 450, 458, 469, 471, 533, 562, 564]
+ 
+     decimals_pos_neg_dollar = [3, 4, 5, 6, 7, 13, 16, 24, 27, 28, 29, 30, 31, 32, 33, 34, 37, 38, 39, 40,
+         53, 54, 55, 57, 58, 60, 61, 63, 64, 65, 66, 68, 69, 70, 71, 73, 74, 77, 78, 80, 81, 103, 104, 105,
+         106, 107, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 121, 122, 123, 124, 125, 126, 128,
+         129, 131, 132, 134, 135, 139, 146, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165,
+         166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 180, 181, 182, 183, 184, 185, 186,
+         193, 194, 195, 204, 205, 207, 209, 210, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 223, 224,
+         225, 226, 227, 228, 229, 230, 231, 232, 249, 250, 251, 252, 253, 254, 255, 256, 258, 259, 260, 261,
+         263, 266, 267, 270, 271, 272, 277, 299, 301, 302, 305, 306, 307, 309, 312, 313, 315, 319, 320, 324,
+         326, 327, 330, 334, 340, 348, 350, 351, 352, 353, 354, 355, 356, 357, 358, 361, 362, 363, 364, 368,
+         371, 373, 377, 380, 382, 400, 401, 402, 403, 405, 406, 409, 410, 411, 413, 435, 439, 446, 447, 448,
+         449, 450, 451, 452, 453, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 469, 470, 471, 477,
+         498, 500, 502, 503, 507, 512, 518, 519, 520, 532, 533, 553, 554, 555, 556, 557, 558, 559, 560, 561,
+         562, 564, 565, 572, 577]
+ 
+     #match_col(data[0],'(\d*,?\d*)+')
+     commas = []
+     #match_col(data[0],'(\d*,?\d*)+')
+     commas_and_all = []
+ 
+     #full_list = test_list + train_list
+     train_list = []
+     full_list = decimals_pos_neg_dollar
+ 
+     regextasks = [
+         Task("Data column no. " + str(i),
+              arrow(tpregex, tpregex),
+              [((), example) for example in task]
+              ) for i, task in enumerate(tasklist) if i in full_list]
+ 
+     for i in train_list:
+         regextasks[i].mustTrain = True
+ 
+     return regextasks
+ 
+ 
+ def makeHandPickedTasks():
+ 
+     #load new data:
+ 
+     taskfile = os.path.join(get_data_dir(), "regex_data_csv_900.p")
+ 
+     with open(taskfile, 'rb') as handle:
+         data = dill.load(handle)
+ 
+     tasklist = data[0] #a list of tasks
+ 
+     full_list = list(range(199)) + \
+         [209,218,222,223,224,225,226] + \
+         list(range(222,233)) + \
+         [235,237,238,239,243,244,245,252,253,254,255,257,258,259,260,261,264,265,269,272,274] + \
+         list(range(275,291)) + \
+         [295,297,300,303,304,305,306,310,311,312,314,315,316,320,321,323,327,329,330,333,334,335,337,338,339,340,341,342,343,344] + \
+         list(range(348,359)) + \
+         [361,369,373,379,380,382,387,403,405,407,408] + \
+         list(range(409,417)) + \
+         list(range(418,437)) + \
+         list(range(440,444)) + \
+         list(range(446,452)) + \
+         list(range(456,460)) + \
+         list(range(466,472)) + \
+         [503,504]
+ 
+     regextasks = [
+         Task("Data column no. " + str(i),
+              arrow(tpregex, tpregex),
+              [((), example) for example in task]
+              ) for i, task in enumerate(tasklist) if i in full_list]
+ 
+     #for i in train_list:
+     #    regextasks[i].mustTrain = True
+ 
+     return regextasks
+ 
+ 
+ def makeNewTasks(include_only=None):
+ 
+     #load new data:
+ 
+     taskfile = os.path.join(get_data_dir(), "csv_filtered_all_background_novel.p")
+ 
+     with open(taskfile, 'rb') as handle:
+         data = dill.load(handle)
+ 
+     tasklist = data['background'] #a list of tasks
+ 
+     if include_only:
+         regextasks = [
+             Task("Data column no. " + str(i),
+                  arrow(tpregex, tpregex),
+                  [((), example) for example in task['train']]
+                  ) for i, task in enumerate(tasklist) if i in include_only]
+     else:
+         regextasks = [
+             Task("Data column no. " + str(i),
+                  arrow(tpregex, tpregex),
+                  [((), example) for example in task['train']]
+                  ) for i, task in enumerate(tasklist)]
+ 
+     #for i in train_list:
+     #    regextasks[i].mustTrain = True
+ 
+     return regextasks
+ 
+ 
+ REGEXTASKS = None
+ def regexHeldOutExamples(task, include_only=None):
+ 
+     #load new data:
+     global REGEXTASKS
+     if REGEXTASKS is None:
+         taskfile = os.path.join(get_data_dir(), "csv_filtered_all_background_novel.p")
+ 
+         with open(taskfile, 'rb') as handle:
+             data = dill.load(handle)
+ 
+         tasklist = data['background'] #a list of tasks
+ 
+         if include_only:
+             regextasks = [
+                 Task("Data column no. " + str(i),
+                      arrow(tpregex, tpregex),
+                      [((), example) for example in _task['test']]
+                      ) for i, _task in enumerate(tasklist) if i in include_only]
+         else:
+             regextasks = [
+                 Task("Data column no. " + str(i),
+                      arrow(tpregex, tpregex),
+                      [((), example) for example in _task['test']]
+                      ) for i, _task in enumerate(tasklist)]
+ 
+         #for i in train_list:
+         #    regextasks[i].mustTrain = True
+ 
+         REGEXTASKS = {t.name: t.examples for t in regextasks}
+     fullTask = REGEXTASKS[task.name]
300
+ return fullTask
301
+
302
+
303
+
304
+ def makeNewNumberTasks():
305
+
306
+ tasks = makeNewTasks()
307
+ numberTasks = [t for t in tasks if not any(p in ex for p in printable[10:62] for _, ex in t.examples)]
308
+ return numberTasks
309
+
310
+
311
+ # a helper function which takes a list of lists and sees which match a specific regex.
312
+ def match_col(dataset, rstring):
313
+ r = pregex.create(rstring)
314
+ matches = []
315
+ for i, col in enumerate(dataset):
316
+ score = sum([r.match(example) for example in col])
317
+ if score != float('-inf'):
318
+ matches.append(i)
319
+ return matches
320
+
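A hedged sketch of how `match_col` above appears to have been used to produce the hand-curated index lists earlier in this file. The pickle layout (`data[0]` as a list of string columns) is taken from the loading code above; the regex string is one of the patterns quoted in the comments, and the variable names here are illustrative only:

import os
import dill

# Illustrative only: score every column against one probabilistic regex and
# keep the indices where no example scores -inf, exactly as match_col does.
with open(os.path.join(get_data_dir(), "regex_data_csv_900.p"), 'rb') as handle:
    data = dill.load(handle)
all_caps_candidates = match_col(data[0], '(\\u+ ?)+')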
+ if __name__ == "__main__":
+     import argparse
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--include_only",
+                         default=None,
+                         nargs="+",
+                         type=int)
+     args = parser.parse_args()
+ 
+ 
+     def show_tasks(dataset):
+         task_list = []
+         for task in dataset:
+             print(task.name)
+             print([example[1] for example in task.examples[:20]])
+             task_list.append([example[1] for example in task.examples])
+         return task_list
+ 
+     task = {"number": makeNumberTasks,
+             "words": makeWordTasks,
+             "all": makeLongTasks,
+             "new": makeNewTasks}['new']
+ 
+ 
+     x = show_tasks(task(args.include_only))
+ 
+ 
dreamcoder/domains/regex/regexPrimitives.py ADDED
@@ -0,0 +1,367 @@
+ import sys
+ from dreamcoder.program import Primitive
+ from dreamcoder.grammar import Grammar
+ from dreamcoder.type import arrow, tpregex
+ from string import printable
+ 
+ try:
+     from pregex import pregex
+ except:
+     print("Failure to load pregex. This is only acceptable if using pypy",file=sys.stderr)
+ 
+ 
+ # evaluation to regular regex form. then I can unflatten using Luke's stuff.
+ 
+ 
+ def _kleene(x): return pregex.KleeneStar(x, p=0.25)
+ 
+ 
+ def _plus(x): return pregex.Plus(x, p=0.25)
+ 
+ 
+ def _maybe(x): return pregex.Maybe(x)
+ 
+ 
+ # maybe should be reversed#"(" + x + "|" + y + ")"
+ def _alt(x): return lambda y: pregex.Alt([x, y])
+ 
+ 
+ def _concat(x): return lambda y: pregex.Concat([x, y])  # "(" + x + y + ")"
+ 
+ 
+ #For sketch:
+ def _kleene_5(x): return pregex.KleeneStar(x)
+ 
+ def _plus_5(x): return pregex.Plus(x)
+ 
+ 
+ disallowed = [
+     ("#", "hash"),
+     ("!", "bang"),
+     ("\"", "double_quote"),
+     ("$", "dollar"),
+     ("%", "percent"),
+     ("&", "ampersand"),
+     ("'", "single_quote"),
+     (")", "right_paren"),
+     ("(", "left_paren"),
+     ("*", "astrisk"),
+     ("+", "plus"),
+     (",", "comma"),
+     ("-", "dash"),
+     (".", "period"),
+     ("/", "slash"),
+     (":", "colon"),
+     (";", "semicolon"),
+     ("<", "less_than"),
+     ("=", "equal"),
+     (">", "greater_than"),
+     ("?", "question_mark"),
+     ("@", "at"),
+     ("[", "left_bracket"),
+     ("\\", "backslash"),
+     ("]", "right_bracket"),
+     ("^", "carrot"),
+     ("_", "underscore"),
+     ("`", "backtick"),
+     ("|", "bar"),
+     ("}", "right_brace"),
+     ("{", "left_brace"),
+     ("~", "tilde"),
+     (" ", "space"),
+     ("\t", "tab")
+ ]
+ 
+ disallowed_list = [char for char, _ in disallowed]
+ 
+ class PRC(): #PregexContinuation
+     def __init__(self, f, arity=0, args=[]):
+         self.f = f
+         self.arity = arity
+         self.args = args
+ 
+     def __call__(self, pre):
+ 
+         if self.arity == len(self.args):
+             if self.arity == 0: return pregex.Concat([self.f, pre])
+             elif self.arity == 1: return pregex.Concat([self.f(*self.args), pre])
+             else: return pregex.Concat([self.f(self.args), pre]) #this line is bad, need brackets around input to f if f is Alt
+         else: return PRC(self.f, self.arity, args=self.args+[pre(pregex.String(""))])
+ 
+ 
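The `PRC` ("pregex continuation") wrapper above is what lets every primitive below carry type `tpregex -> tpregex`: a regex is built by threading a continuation (the rest of the pattern) through each combinator. A minimal sketch of how the wrapped values compose, assuming only that `pregex` imported successfully:

# arity 0: prepend a literal to the continuation; arity 1: first build the
# body from one continuation, then prepend the wrapped result to the next.
abc = PRC(pregex.String("abc"))
star = PRC(pregex.KleeneStar, 1)
# star(abc) runs abc against an empty continuation to build the loop body;
# applying the result to String("") terminates the whole pattern.
r = star(abc)(pregex.String(""))
print(r)  # Concat of KleeneStar(Concat(String("abc"), String(""))) and ""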
+ def concatPrimitives():
+     return [Primitive("string_" + i, arrow(tpregex, tpregex), PRC(pregex.String(i))) for i in printable[:-4] if i not in disallowed_list
+     ] + [
+         Primitive("string_" + name, arrow(tpregex, tpregex), PRC(pregex.String(char))) for char, name in disallowed
+     ] + [
+         Primitive("r_dot", arrow(tpregex, tpregex), PRC(pregex.dot)),
+         Primitive("r_d", arrow(tpregex, tpregex), PRC(pregex.d)),
+         Primitive("r_s", arrow(tpregex, tpregex), PRC(pregex.s)),
+         Primitive("r_w", arrow(tpregex, tpregex), PRC(pregex.w)),
+         Primitive("r_l", arrow(tpregex, tpregex), PRC(pregex.l)),
+         Primitive("r_u", arrow(tpregex, tpregex), PRC(pregex.u)),
+         #todo
+         Primitive("r_kleene", arrow(arrow(tpregex, tpregex), arrow(tpregex,tpregex)), PRC(pregex.KleeneStar,1)),
+         Primitive("r_plus", arrow(arrow(tpregex, tpregex), arrow(tpregex,tpregex)), PRC(pregex.Plus,1)),
+         Primitive("r_maybe", arrow(arrow(tpregex, tpregex), arrow(tpregex,tpregex)), PRC(pregex.Maybe,1)),
+         Primitive("r_alt", arrow(arrow(tpregex, tpregex) , arrow(tpregex, tpregex), arrow(tpregex,tpregex)), PRC(pregex.Alt, 2)),
+     ]
+ 
+ def strConstConcatPrimitives():
+     return [Primitive("string_" + i, arrow(tpregex, tpregex), PRC(pregex.String(i))) for i in printable[:-4] if i not in disallowed_list
+     ] + [
+         Primitive("string_" + name, arrow(tpregex, tpregex), PRC(pregex.String(char))) for char, name in disallowed
+     ] + [
+         Primitive("r_dot", arrow(tpregex, tpregex), PRC(pregex.dot)),
+         Primitive("r_d", arrow(tpregex, tpregex), PRC(pregex.d)),
+         Primitive("r_s", arrow(tpregex, tpregex), PRC(pregex.s)),
+         Primitive("r_w", arrow(tpregex, tpregex), PRC(pregex.w)),
+         Primitive("r_l", arrow(tpregex, tpregex), PRC(pregex.l)),
+         Primitive("r_u", arrow(tpregex, tpregex), PRC(pregex.u)),
+         #todo
+         Primitive("r_kleene", arrow(arrow(tpregex, tpregex), arrow(tpregex,tpregex)), PRC(pregex.KleeneStar,1)),
+         Primitive("r_plus", arrow(arrow(tpregex, tpregex), arrow(tpregex,tpregex)), PRC(pregex.Plus,1)),
+         Primitive("r_maybe", arrow(arrow(tpregex, tpregex), arrow(tpregex,tpregex)), PRC(pregex.Maybe,1)),
+         Primitive("r_alt", arrow(arrow(tpregex, tpregex) , arrow(tpregex, tpregex), arrow(tpregex,tpregex)), PRC(pregex.Alt, 2)),
+     ] + [
+         Primitive("r_const", arrow(tpregex, tpregex), None)
+     ]
+ 
+ 
+ def reducedConcatPrimitives():
+     #uses strConcat!!
+     #[Primitive("empty_string", arrow(tpregex, tpregex), PRC(pregex.String("")))
+     #] + [
+     return [Primitive("string_" + i, arrow(tpregex, tpregex), PRC(pregex.String(i))) for i in printable[:-4] if i not in disallowed_list
+     ] + [
+         Primitive("string_" + name, arrow(tpregex, tpregex), PRC(pregex.String(char))) for char, name in disallowed
+     ] + [
+         Primitive("r_dot", arrow(tpregex, tpregex), PRC(pregex.dot)),
+         Primitive("r_d", arrow(tpregex, tpregex), PRC(pregex.d)),
+         Primitive("r_s", arrow(tpregex, tpregex), PRC(pregex.s)),
+         #Primitive("r_w", arrow(tpregex, tpregex), PRC(pregex.w)),
+         Primitive("r_l", arrow(tpregex, tpregex), PRC(pregex.l)),
+         Primitive("r_u", arrow(tpregex, tpregex), PRC(pregex.u)),
+         #todo
+         Primitive("r_kleene", arrow(arrow(tpregex, tpregex), arrow(tpregex,tpregex)), PRC(pregex.KleeneStar,1)),
+         #Primitive("r_plus", arrow(arrow(tpregex, tpregex), arrow(tpregex,tpregex)), PRC(pregex.Plus,1)),
+         Primitive("r_maybe", arrow(arrow(tpregex, tpregex), arrow(tpregex,tpregex)), PRC(pregex.Maybe,1)),
+         Primitive("r_alt", arrow(arrow(tpregex, tpregex) , arrow(tpregex, tpregex), arrow(tpregex,tpregex)), PRC(pregex.Alt, 2)),
+     ] + [
+         Primitive("r_const", arrow(tpregex, tpregex), None)
+     ]
+ 
+ 
+ def sketchPrimitives():
+     return [Primitive("string_" + i, tpregex, pregex.String(i)) for i in printable[:-4] if i not in disallowed_list
+     ] + [
+         Primitive("string_" + name, tpregex, pregex.String(char)) for char, name in disallowed
+     ] + [
+         Primitive("r_dot", tpregex, pregex.dot),
+         Primitive("r_d", tpregex, pregex.d),
+         Primitive("r_s", tpregex, pregex.s),
+         Primitive("r_w", tpregex, pregex.w),
+         Primitive("r_l", tpregex, pregex.l),
+         Primitive("r_u", tpregex, pregex.u),
+         Primitive("r_kleene", arrow(tpregex, tpregex), _kleene_5),
+         Primitive("r_plus", arrow(tpregex, tpregex), _plus_5),
+         Primitive("r_maybe", arrow(tpregex, tpregex), _maybe),
+         Primitive("r_alt", arrow(tpregex, tpregex, tpregex), _alt),
+         Primitive("r_concat", arrow(tpregex, tpregex, tpregex), _concat),
+     ]
+ 
+ def basePrimitives():
+     return [Primitive("string_" + i, tpregex, pregex.String(i)) for i in printable[:-4] if i not in disallowed_list
+     ] + [
+         Primitive("string_" + name, tpregex, pregex.String(char)) for char, name in disallowed
+     ] + [
+         Primitive("r_dot", tpregex, pregex.dot),
+         Primitive("r_d", tpregex, pregex.d),
+         Primitive("r_s", tpregex, pregex.s),
+         Primitive("r_w", tpregex, pregex.w),
+         Primitive("r_l", tpregex, pregex.l),
+         Primitive("r_u", tpregex, pregex.u),
+         Primitive("r_kleene", arrow(tpregex, tpregex), _kleene),
+         Primitive("r_plus", arrow(tpregex, tpregex), _plus),
+         Primitive("r_maybe", arrow(tpregex, tpregex), _maybe),
+         Primitive("r_alt", arrow(tpregex, tpregex, tpregex), _alt),
+         Primitive("r_concat", arrow(tpregex, tpregex, tpregex), _concat),
+     ]
+ 
+ 
+ 
+ def altPrimitives():
+     return [
+         Primitive("empty_string", tpregex, pregex.String(""))
+     ] + [
+         Primitive("string_" + i, tpregex, pregex.String(i)) for i in printable[:-4] if i not in disallowed_list
+     ] + [
+         Primitive("string_" + name, tpregex, pregex.String(char)) for char, name in disallowed
+     ] + [
+         Primitive("r_dot", tpregex, pregex.dot),
+         Primitive("r_d", tpregex, pregex.d),
+         Primitive("r_s", tpregex, pregex.s),
+         Primitive("r_w", tpregex, pregex.w),
+         Primitive("r_l", tpregex, pregex.l),
+         Primitive("r_u", tpregex, pregex.u),
+         Primitive("r_kleene", arrow(tpregex, tpregex), _kleene),
+         #Primitive("r_plus", arrow(tpregex, tpregex), _plus),
+         Primitive("r_maybe", arrow(tpregex, tpregex), _maybe),
+         Primitive("r_alt", arrow(tpregex, tpregex, tpregex), _alt),
+         Primitive("r_concat", arrow(tpregex, tpregex, tpregex), _concat),
+     ]
+ 
+ def alt2Primitives():
+     return [
+         Primitive("empty_string", tpregex, pregex.String(""))
+     ] + [
+         Primitive("string_" + i, tpregex, pregex.String(i)) for i in printable[:-4] if i not in disallowed_list
+     ] + [
+         Primitive("string_" + name, tpregex, pregex.String(char)) for char, name in disallowed
+     ] + [
+         Primitive("r_dot", tpregex, pregex.dot),
+         Primitive("r_d", tpregex, pregex.d),
+         Primitive("r_s", tpregex, pregex.s),
+         Primitive("r_w", tpregex, pregex.w),
+         Primitive("r_l", tpregex, pregex.l),
+         Primitive("r_u", tpregex, pregex.u),
+         Primitive("r_kleene", arrow(tpregex, tpregex), _kleene),
+         #Primitive("r_plus", arrow(tpregex, tpregex), _plus),
+         #Primitive("r_maybe", arrow(tpregex, tpregex), _maybe),
+         Primitive("r_alt", arrow(tpregex, tpregex, tpregex), _alt),
+         Primitive("r_concat", arrow(tpregex, tpregex, tpregex), _concat),
+     ]
+ 
+ def easyWordsPrimitives():
+     return [
+         Primitive("string_" + i, tpregex, pregex.String(i)) for i in printable[10:62] if i not in disallowed_list
+     ] + [
+         Primitive("r_d", tpregex, pregex.d),
+         Primitive("r_s", tpregex, pregex.s),
+         #Primitive("r_w", tpregex, pregex.w),
+         Primitive("r_l", tpregex, pregex.l),
+         Primitive("r_u", tpregex, pregex.u),
+         Primitive("r_kleene", arrow(tpregex, tpregex), _kleene),
+         Primitive("r_plus", arrow(tpregex, tpregex), _plus),
+         Primitive("r_maybe", arrow(tpregex, tpregex), _maybe),
+         Primitive("r_alt", arrow(tpregex, tpregex, tpregex), _alt),
+         Primitive("r_concat", arrow(tpregex, tpregex, tpregex), _concat),
+     ]
+ 
+ 
+ #def _wrapper(x): return lambda y: y
+ 
+ #specials = [".","*","+","?","|"]
+ """
+ >>> import pregex as pre
+ >>> abc = pre.CharacterClass("abc", [0.1, 0.1, 0.8], name="MyConcept")
+ >>> abc.sample()
+ 'b'
+ >>> abc.sample()
+ 'c'
+ >>> abc.sample()
+ 'c'
+ >>> abc.match("c")
+ -0.2231435513142097
+ >>> abc.match("a")
+ -2.3025850929940455
+ >>> abc
+ MyConcept
+ >>> x = pre.KleeneStar(abc)
+ >>> x.match("aabbac")
+ -16.58809928020405
+ >>> x.sample()
+ ''
+ >>> x.sample()
+ ''
+ >>> x.sample()
+ 'cbcacc'
+ >>> x
+ (KleeneStar 0.5 MyConcept)
+ >>> str(x)
+ 'MyConcept*'
+ """
+ 
+ 
+ def emp_dot(corpus): return pregex.CharacterClass(printable[:-4], emp_distro_from_corpus(corpus, printable[:-4]), name=".")
+ 
+ def emp_d(corpus): return pregex.CharacterClass(printable[:10], emp_distro_from_corpus(corpus, printable[:10]), name="\\d")
+ 
+ #emp_s = pre.CharacterClass(slist, [], name="emp\\s") #may want to forgo this one.
+ 
+ def emp_dot_no_letter(corpus): return pregex.CharacterClass(printable[:10]+printable[62:], emp_distro_from_corpus(corpus, printable[:10]+printable[62:]), name=".")
+ 
+ def emp_w(corpus): return pregex.CharacterClass(printable[:62], emp_distro_from_corpus(corpus, printable[:62]), name="\\w")
+ 
+ def emp_l(corpus): return pregex.CharacterClass(printable[10:36], emp_distro_from_corpus(corpus, printable[10:36]), name="\\l")
+ 
+ def emp_u(corpus): return pregex.CharacterClass(printable[36:62], emp_distro_from_corpus(corpus, printable[36:62]), name="\\u")
+ 
+ 
+ def emp_distro_from_corpus(corpus, char_list):
+     from collections import Counter
+     c = Counter(char for task in corpus for example in task.examples for string in example[1] for char in string)
+     n = sum(c[char] for char in char_list)
+     return [c[char]/n for char in char_list]
+ 
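A self-contained check of the normalization `emp_distro_from_corpus` performs: it counts character occurrences across all task outputs and normalizes over the given character class, ignoring characters outside the class. The data here is a hand-built stand-in (hypothetical, for illustration; the real call sites pass task corpora):

from collections import Counter

# Toy data: if every observed output character were "112", the distribution
# over the digits "0123456789" puts mass 2/3 on '1' and 1/3 on '2'.
c = Counter("112")
char_list = "0123456789"
n = sum(c[ch] for ch in char_list)
distro = [c[ch] / n for ch in char_list]
assert distro[1] == 2 / 3 and distro[2] == 1 / 3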
+ 
+ 
+ def matchEmpericalPrimitives(corpus):
+     return lambda: [
+         Primitive("empty_string", tpregex, pregex.String(""))
+     ] + [
+         Primitive("string_" + i, tpregex, pregex.String(i)) for i in printable[:-4] if i not in disallowed_list
+     ] + [
+         Primitive("string_" + name, tpregex, pregex.String(char)) for char, name in disallowed
+     ] + [
+         Primitive("r_dot", tpregex, emp_dot(corpus) ),
+         Primitive("r_d", tpregex, emp_d(corpus) ),
+         Primitive("r_s", tpregex, pregex.s),
+         Primitive("r_w", tpregex, emp_w(corpus) ),
+         Primitive("r_l", tpregex, emp_l(corpus) ),
+         Primitive("r_u", tpregex, emp_u(corpus) ),
+         Primitive("r_kleene", arrow(tpregex, tpregex), _kleene),
+         #Primitive("r_plus", arrow(tpregex, tpregex), _plus),
+         Primitive("r_maybe", arrow(tpregex, tpregex), _maybe),
+         Primitive("r_alt", arrow(tpregex, tpregex, tpregex), _alt),
+         Primitive("r_concat", arrow(tpregex, tpregex, tpregex), _concat),
+     ]
+ 
+ def matchEmpericalNoLetterPrimitives(corpus):
+     return lambda: [
+         Primitive("empty_string", tpregex, pregex.String(""))
+     ] + [
+         Primitive("string_" + i, tpregex, pregex.String(i)) for i in printable[:-4] if i not in disallowed_list + list(printable[10:62])
+     ] + [
+         Primitive("string_" + name, tpregex, pregex.String(char)) for char, name in disallowed
+     ] + [
+         Primitive("r_dot", tpregex, emp_dot_no_letter(corpus) ),
+         Primitive("r_d", tpregex, emp_d(corpus) ),
+         Primitive("r_s", tpregex, pregex.s),
+         Primitive("r_kleene", arrow(tpregex, tpregex), _kleene),
+         #Primitive("r_plus", arrow(tpregex, tpregex), _plus),
+         #Primitive("r_maybe", arrow(tpregex, tpregex), _maybe),
+         Primitive("r_alt", arrow(tpregex, tpregex, tpregex), _alt),
+         Primitive("r_concat", arrow(tpregex, tpregex, tpregex), _concat),
+     ]
+ 
+ 
+ if __name__=='__main__':
+     concatPrimitives()
+     from dreamcoder.program import Program
+ 
+     p=Program.parse("(lambda (r_kleene (lambda (r_maybe (lambda (string_x $0)) $0)) $0))")
+     print(p)
+     print(p.runWithArguments([pregex.String("")]))
+ 
+     prims = concatPrimitives()
+     g = Grammar.uniform(prims)
+ 
+     for i in range(100):
+         prog = g.sample(arrow(tpregex,tpregex))
+         preg = prog.runWithArguments([pregex.String("")])
+         print("preg:", preg.__repr__())
+         print("sample:", preg.sample())
+ 
+ 
+ 
dreamcoder/domains/text/__init__.py ADDED
File without changes
dreamcoder/domains/text/main.py ADDED
@@ -0,0 +1,270 @@
+ from dreamcoder.dreamcoder import ecIterator
+ from dreamcoder.domains.text.makeTextTasks import makeTasks, loadPBETasks
+ from dreamcoder.domains.text.textPrimitives import primitives
+ from dreamcoder.domains.list.listPrimitives import bootstrapTarget
+ from dreamcoder.enumeration import *
+ 
+ import os
+ import datetime
+ import pickle  # used by sygusCompetition below to export results
+ import random
+ from functools import reduce
+ import dill
+ 
+ 
+ class ConstantInstantiateVisitor(object):
+     def __init__(self, words):
+         self.words = words
+ 
+     def primitive(self, e):
+         if e.name == "STRING":
+             return Primitive("STRING", e.tp, random.choice(self.words))
+         return e
+ 
+     def invented(self, e): return e.body.visit(self)
+ 
+     def index(self, e): return e
+ 
+     def application(self, e):
+         return Application(e.f.visit(self), e.x.visit(self))
+ 
+     def abstraction(self, e):
+         return Abstraction(e.body.visit(self))
+ 
+ 
+ try:
+     from dreamcoder.recognition import *
+     class LearnedFeatureExtractor(RecurrentFeatureExtractor):
+         special = 'string'
+ 
+         def tokenize(self, examples):
+             def tokenize_example(xs,y):
+                 if not isinstance(y, list): y = [y]
+                 return xs,y
+             return [tokenize_example(*e) for e in examples]
+ 
+         def __init__(self, tasks, testingTasks=[], cuda=False):
+             lexicon = {c
+                        for t in tasks + testingTasks
+                        for xs, y in self.tokenize(t.examples)
+                        for c in reduce(lambda u, v: u + v, list(xs) + [y])}
+             self.recomputeTasks = True
+ 
+             super(LearnedFeatureExtractor, self).__init__(lexicon=list(lexicon),
+                                                           H=64,
+                                                           tasks=tasks,
+                                                           bidirectional=True,
+                                                           cuda=cuda)
+             self.MAXINPUTS = 8
+ 
+         def taskOfProgram(self, p, tp):
+             # Instantiate STRING w/ random words
+             p = p.visit(ConstantInstantiateVisitor.SINGLE)
+             return super(LearnedFeatureExtractor, self).taskOfProgram(p, tp)
+ except:
+     pass
+ 
+ ### COMPETITION CODE
+ 
+ def competeOnOneTask(checkpoint, task,
+                      CPUs=8, timeout=3600, evaluationTimeout=0.0005):
+     if checkpoint.recognitionModel is not None:
+         recognizer = checkpoint.recognitionModel
+         challengeFrontiers, times, bestSearchTime = \
+             recognizer.enumerateFrontiers([task],
+                                           CPUs=CPUs,
+                                           maximumFrontier=1,
+                                           enumerationTimeout=timeout,
+                                           evaluationTimeout=evaluationTimeout)
+     else:
+         challengeFrontiers, times, bestSearchTimes = \
+             multicoreEnumeration(checkpoint.grammars[-1], [task],
+                                  CPUs=CPUs,
+                                  maximumFrontier=1,
+                                  enumerationTimeout=timeout,
+                                  evaluationTimeout=evaluationTimeout)
+     if len(times) == 0: return None, task
+     assert len(times) == 1
+     return times[0], task
+ 
+ 
+ 
+ def sygusCompetition(checkpoints, tasks):
+     from pathos.multiprocessing import Pool
+     import datetime
+ 
+     # map from task to list of search times, one for each checkpoint.
+     # search time will be None if it is not solved
+     searchTimes = {t: [] for t in tasks}
+ 
+     CPUs = int(8/len(checkpoints))
+     maxWorkers = int(numberOfCPUs()/CPUs)
+     workers = Pool(maxWorkers)
+     eprint(f"You gave me {len(checkpoints)} checkpoints to ensemble. Each checkpoint will get {CPUs} CPUs. Creating a pool of {maxWorkers} worker processes.")
+     timeout = 3600
+ 
+ 
+     promises = []
+     for t in tasks:
+         for checkpoint in checkpoints:
+             promise = workers.apply_async(competeOnOneTask,
+                                           (checkpoint,t),
+                                           {"CPUs": CPUs,
+                                            "timeout": timeout})
+             promises.append(promise)
+     eprint(f"Queued {len(promises)} jobs.")
+     for promise in promises:
+         dt, task = promise.get()
+         if dt is not None:
+             searchTimes[task].append(dt)
+ 
+     searchTimes = {t: min(ts) if len(ts) > 0 else None
+                    for t,ts in searchTimes.items()}
+ 
+     fn = "experimentOutputs/text_competition_%s.p"%(datetime.datetime.now().isoformat())
+     with open(fn,"wb") as handle:
+         pickle.dump(searchTimes, handle)
+     eprint()
+ 
+     hits = sum( t is not None for t in searchTimes.values() )
+     total = len(searchTimes)
+     percentage = 100*hits/total
+     eprint("Hits %d/%d = %f\n"%(hits, total, percentage))
+     eprint()
+     eprint("Exported competition results to",fn)
+ 
+ 
+ 
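A worked instance of the pool sizing arithmetic in `sygusCompetition`, under assumed numbers; `numberOfCPUs()` is the real utility used above, while the 16-core figure and the stand-in checkpoints are assumptions for illustration:

# Hypothetical: ensembling 2 checkpoints on a 16-core machine.
checkpoints = [None, None]        # stand-ins for two loaded checkpoints
CPUs = int(8 / len(checkpoints))  # 4 CPUs per enumeration job
maxWorkers = int(16 / CPUs)       # 4 concurrent workers if numberOfCPUs() == 16
assert (CPUs, maxWorkers) == (4, 4)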
+ def text_options(parser):
+     parser.add_argument(
+         "--showTasks",
+         action="store_true",
+         default=False,
+         help="show the training test and challenge tasks and then exit")
+     parser.add_argument(
+         "--trainChallenge",
+         action="store_true",
+         default=False,
+         # "%%" because argparse %-formats help strings
+         help="Incorporate a random 50%% of the challenge problems into the training set")
+     parser.add_argument(
+         "--onlyChallenge",
+         action="store_true",
+         default=False,
+         help="Only train on challenge problems and have testing problems.")
+     parser.add_argument(
+         "--latest",
+         action="store_true",
+         default=False,
+         help="evaluate on latest sygus problems rather than problems used in ec2 paper")
+     parser.add_argument(
+         "--noMap", action="store_true", default=False,
+         help="Disable built-in map primitive")
+     parser.add_argument(
+         "--noLength", action="store_true", default=False,
+         help="Disable built-in length primitive")
+     parser.add_argument(
+         "--noUnfold", action="store_true", default=False,
+         help="Disable built-in unfold primitive")
+     parser.add_argument(
+         "--compete",
+         nargs='+',
+         default=None,
+         type=str,
+         help="Do a simulated sygus competition (1hr+8cpus/problem) on the sygus tasks, restoring from provided checkpoint(s). If multiple checkpoints are provided, then we ensemble the models.")
+ 
+ 
+ def main(arguments):
+     """
+     Takes the return value of the `commandlineArguments()` function as input and
+     trains/tests the model on manipulating sequences of text.
+     """
+ 
+     tasks = makeTasks()
+     eprint("Generated", len(tasks), "tasks")
+ 
+     for t in tasks:
+         t.mustTrain = False
+ 
+     test, train = testTrainSplit(tasks, 1.)
+     eprint("Split tasks into %d/%d test/train" % (len(test), len(train)))
+ 
+     latest = arguments.pop("latest")
+     challenge, challengeCheating = loadPBETasks("data/sygus" if latest else "PBE_Strings_Track")
+     eprint("Got %d challenge PBE tasks" % len(challenge))
+ 
+     if arguments.pop('trainChallenge'):
+         challengeTest, challengeTrain = testTrainSplit(challenge, 0.5)
+         challenge = challengeTest
+         train += challengeTrain
+         eprint(
+             "Incorporating %d (50%%) challenge problems into the training set." %
+             (len(challengeTrain)),
+             "We will evaluate on the held out challenge problems.",
+             "This makes a total of %d training problems." %
+             len(train))
+ 
+     if arguments.pop('onlyChallenge'):
+         train = challenge
+         test = []
+         challenge = []
+         eprint("Training only on sygus problems.")
+ 
+ 
+     ConstantInstantiateVisitor.SINGLE = \
+         ConstantInstantiateVisitor(list(map(list, list({tuple([c for c in s])
+                                                         for t in test + train + challenge
+                                                         for s in t.stringConstants}))))
+ 
+     haveLength = not arguments.pop("noLength")
+     haveMap = not arguments.pop("noMap")
+     haveUnfold = not arguments.pop("noUnfold")
+     eprint(f"Including map as a primitive? {haveMap}")
+     eprint(f"Including length as a primitive? {haveLength}")
+     eprint(f"Including unfold as a primitive? {haveUnfold}")
+     baseGrammar = Grammar.uniform(primitives + [p
+                                                 for p in bootstrapTarget()
+                                                 if (p.name != "map" or haveMap) and \
+                                                    (p.name != "unfold" or haveUnfold) and \
+                                                    (p.name != "length" or haveLength)])
+     challengeGrammar = baseGrammar  # Grammar.uniform(targetTextPrimitives)
+ 
+     evaluationTimeout = 0.0005
+     # We will spend 10 minutes on each challenge problem
+     challengeTimeout = 10 * 60
+ 
+     for t in train + test + challenge:
+         t.maxParameters = 2
+ 
+     if arguments.pop("showTasks"):
+         for source, ts in [("train",tasks),("test",test),("challenge",challenge)]:
+             print(source,"tasks:")
+             for t in ts:
+                 print(t.name)
+                 for xs, y in t.examples:
+                     xs = ['"' + "".join(x) + '"' for x in xs]
+                     y = "".join(y) if isinstance(y,list) else y
+                     print('f(%s) = "%s"' % (", ".join(xs), y))
+                 print("\t{%s}" % (t.stringConstants))
+             print()
+         sys.exit(0)
+ 
+ 
+     competitionCheckpoints = arguments.pop("compete")
+     if competitionCheckpoints:
+         checkpoints = []
+         for competitionCheckpoint in competitionCheckpoints:
+             with open(competitionCheckpoint, 'rb') as handle:
+                 checkpoints.append(dill.load(handle))
+         sygusCompetition(checkpoints, challenge)
+         sys.exit(0)
+ 
+     timestamp = datetime.datetime.now().isoformat()
+     outputDirectory = "experimentOutputs/text/%s"%timestamp
+     os.system("mkdir -p %s"%outputDirectory)
+ 
+     generator = ecIterator(baseGrammar, train,
+                            testingTasks=test + challenge,
+                            outputPrefix="%s/text"%outputDirectory,
+                            evaluationTimeout=evaluationTimeout,
+                            **arguments)
+     for result in generator:
+         pass
dreamcoder/domains/text/makeTextTasks.py ADDED
@@ -0,0 +1,424 @@
+ from dreamcoder.task import *
+ from dreamcoder.type import *
+ from dreamcoder.utilities import *
+ 
+ import random
+ 
+ 
+ def lcs(u, v):
+     # t[(n,m)] = length of longest common string ending at first
+     # n elements of u & first m elements of v
+     t = {}
+ 
+     for n in range(len(u) + 1):
+         for m in range(len(v) + 1):
+             if m == 0 or n == 0:
+                 t[(n, m)] = 0
+                 continue
+ 
+             if u[n - 1] == v[m - 1]:
+                 t[(n, m)] = 1 + t[(n - 1, m - 1)]
+             else:
+                 t[(n, m)] = 0
+     l, n, m = max((l, n, m) for (n, m), l in t.items())
+     return u[n - l:n]
+ 
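Note that `lcs` above returns the longest common *substring* (a contiguous run, as the DP table only extends matches ending at the same positions), not the longest common subsequence. A quick sanity check:

# "corporat" is the longest contiguous run shared by the two words.
assert lcs("incorporate", "corporation") == "corporat"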
+ 
+ delimiters = ['.', ',', ' ', '(', ')', '-']
+ characters = [chr(ord('a') + j)
+               for j in range(26)] + \
+     [chr(ord('A') + j)
+      for j in range(26)] + \
+     [str(j) for j in range(10)] + \
+     ['+']
+ 
+ WORDS = None
+ 
+ 
+ def randomDelimiter():
+     return random.choice(delimiters)
+ 
+ 
+ def randomCharacter():
+     return random.choice(characters)
+ 
+ 
+ def randomWord(minimum=1, predicate=None):
+     global WORDS
+     if WORDS is None:
+         tasks, cheating = loadPBETasks()
+         observations = {''.join(z)
+                         for t in tasks
+                         for xs, y in t.examples
+                         for z in list(xs) + [y]}
+ 
+         def splitMany(s, ds):
+             if len(ds) == 0:
+                 return [s]
+             d = ds[0]
+             ds = ds[1:]
+             s = [w
+                  for z in s.split(d)
+                  for w in splitMany(z, ds)
+                  if len(w) > 0]
+             return s
+ 
+         WORDS = {w
+                  for o in observations
+                  for w in splitMany(o, delimiters)}
+         WORDS = list(sorted(list(WORDS)))
+ 
+     # a disproportionately large fraction of the words have length three
+     # the purpose of this is to decrease the number of 3-length words we have
+     while True:
+         if random.random() > 0.7:
+             candidate = random.choice([w for w in WORDS if len(w) >= minimum])
+         else:
+             candidate = random.choice(
+                 [w for w in WORDS if len(w) >= minimum and len(w) != 3])
+         if predicate is None or predicate(candidate):
+             return candidate
+ 
+ 
+ def randomWords(ds, minimum=1, lb=2, ub=4):
+     words = [randomWord(minimum=minimum)
+              for _ in range(random.choice(range(lb, ub+1)))]
+     s = ""
+     for j,w in enumerate(words):
+         if j > 0:
+             s += random.choice(ds)
+         s += w
+     return s
+ 
+ 
+ def makeTasks():
+     import random
+     random.seed(9)
+ 
+     NUMBEROFEXAMPLES = 4
+ 
+     problems = []
+ 
+     def toList(s): return [c for c in s]
+     # Converts strings into a list of characters depending on the type
+ 
+     def preprocess(x):
+         if isinstance(x, tuple):
+             return tuple(preprocess(z) for z in x)
+         if isinstance(x, list):
+             return [preprocess(z) for z in x]
+         if isinstance(x, str):
+             return [c for c in x]
+         if isinstance(x, bool):
+             return x
+         assert False
+ 
+     def problem(n, examples, needToTrain=False):
+         task = Task(n, guess_arrow_type(examples),
+                     [(preprocess(x),
+                       preprocess(y))
+                      for x, y in examples])
+         # NB: `needToTrain` is accepted but never consulted; every task is
+         # unconditionally marked as mustTrain.
+         task.mustTrain = True
+         problems.append(task)
+ 
+     for d1, d2 in randomPermutation(crossProduct(delimiters, delimiters))[
+             :len(delimiters) * 2]:
+         if d1 != d2:
+             problem("Replace '%s' w/ '%s'" % (d1, d2),
+                     [((x,), x.replace(d1, d2))
+                      for _ in range(NUMBEROFEXAMPLES)
+                      for x in [randomWords(d1)]],
+                     needToTrain=False)
+     for d in delimiters:
+         problem("drop first word delimited by '%s'" % d,
+                 [((x,), d.join(x.split(d)[1:]))
+                  for _ in range(NUMBEROFEXAMPLES)
+                  for x in [randomWords(d)]],
+                 needToTrain=True)
+         for n in [0, 1, -1]:
+             problem("nth (n=%d) word delimited by '%s'" % (n, d),
+                     [((x,), x.split(d)[n])
+                      for _ in range(NUMBEROFEXAMPLES)
+                      for x in [randomWords(d)]],
+                     needToTrain=True)
+     for d1 in delimiters:
+         problem("Append two words delimited by '%s'" % (d1),
+                 [((x, y), x + d1 + y)
+                  for _ in range(NUMBEROFEXAMPLES)
+                  for x in [randomWord()]
+                  for y in [randomWord()]],
+                 needToTrain=True)
+     for d1, d2 in randomPermutation(
+             crossProduct(
+                 delimiters, delimiters))[
+             :len(delimiters)]:
+         problem("Append two words delimited by '%s%s'" % (d1, d2),
+                 [((x, y), x + d1 + d2 + y)
+                  for _ in range(NUMBEROFEXAMPLES)
+                  for x in [randomWord()]
+                  for y in [randomWord()]],
+                 needToTrain=True)
+     for n in range(1, 6):
+         problem("Drop last %d characters" % n,
+                 [((x,), x[:-n])
+                  for _ in range(NUMBEROFEXAMPLES)
+                  for x in [randomWord(minimum=n)]],
+                 needToTrain=True)
+         if n > 1:
+             problem("Take first %d characters" % n,
+                     [((x,), x[:n])
+                      for _ in range(NUMBEROFEXAMPLES)
+                      for x in [randomWord(minimum=n)]],
+                     needToTrain=True)
+     for d1, d2 in randomPermutation(
+             crossProduct(
+                 delimiters, delimiters))[
+             :len(delimiters)]:
+         problem("Extract word delimited by '%s' - '%s'" % (d1, d2),
+                 [((a + d1 + b + d2 + c + d + e,), b)
+                  for _ in range(int(NUMBEROFEXAMPLES / 2))
+                  for d in [d1, d2]
+                  for a in [randomWord()]
+                  for b in [randomWord()]
+                  for c in [randomWord()]
+                  for e in [randomWord()]],
+                 needToTrain=True)
+ 
+     for n in range(len(delimiters)):
+         problem("First letters of words (%s)" % ("I" * (1 + n)),
+                 [((x,), "".join(map(lambda z: z[0], x.split(' '))))
+                  for _ in range(NUMBEROFEXAMPLES)
+                  for x in [randomWords(' ')]
+                  ],
+                 needToTrain=True)
+ 
+     for d in delimiters:
+         problem("Take first character and append '%s'" % d,
+                 [((x,), x[0] + d)
+                  for _ in range(NUMBEROFEXAMPLES)
+                  for x in [randomWord()]],
+                 needToTrain=True)
+ 
+     for n in range(len(delimiters)):
+         problem("Abbreviate separate words (%s)" % ("I" * (n + 1)),
+                 [((x, y), "%s.%s." % (x[0], y[0]))
+                  for _ in range(NUMBEROFEXAMPLES)
+                  for y in [randomWord()]
+                  for x in [randomWord()]])
+         d = delimiters[n]
+         problem("Abbreviate words separated by '%s'" % d,
+                 [((x + d + y,), "%s.%s." % (x[0], y[0]))
+                  for _ in range(NUMBEROFEXAMPLES)
+                  for y in [randomWord()]
+                  for x in [randomWord()]])
+ 
+     for n in range(len(delimiters)):
+         problem("Append 2 strings (%s)" % ('I' * (n + 1)),
+                 [((x, y), x + y)
+                  for _ in range(NUMBEROFEXAMPLES)
+                  for y in [randomWord()]
+                  for x in [randomWord()]],
+                 needToTrain=True)
+ 
+     for n in range(len(delimiters)):
+         w = randomWord(minimum=3)
+         problem("Prepend '%s'" % w,
+                 [((x,), w + x)
+                  for _ in range(NUMBEROFEXAMPLES)
+                  for x in [randomWord()]])
+         w = randomWord(minimum=3)
+         problem("Append '%s'" % w,
+                 [((x,), x + w)
+                  for _ in range(NUMBEROFEXAMPLES)
+                  for x in [randomWord()]])
+         w = randomWord(minimum=3)
+         problem("Prepend '%s' to first word" % w,
+                 [((x + ' ' + y,), w + x)
+                  for _ in range(NUMBEROFEXAMPLES)
+                  for x in [randomWord()]
+                  for y in [randomWord()]])
+ 
+     for n in range(1,6):
+         problem("parentheses around a single word (%s)"%('I'*n),
+                 [((w,),"(%s)"%w)
+                  for _ in range(NUMBEROFEXAMPLES)
+                  for w in [randomWord()] ])
+     problem("parentheses around first word",
+             [((w + " " + s,),"(%s)"%w)
+              for _ in range(NUMBEROFEXAMPLES)
+              for w in [randomWord()]
+              for s in [randomWords(" ")] ])
+     problem("parentheses around second word",
+             [((s,), "(%s)"%(s.split(" ")[1]))
+              for _ in range(NUMBEROFEXAMPLES)
+              for s in [randomWords(" ")] ])
+ 
+     allowed = [d for d in delimiters if d not in "()"]
+     for d1,d2 in randomPermutation(crossProduct(allowed, allowed))[:len(delimiters)]:
+         problem("parentheses around word delimited by '%s' & '%s'"%(d1,d2),
+                 [((prefix + d1 + word + d2 + suffix,),
+                   prefix + d1 + '(' + word + ')' + d2 + suffix)
+                  for _ in range(NUMBEROFEXAMPLES)
+                  for prefix in [randomWords("", lb=0, ub=1)]
+                  for suffix in [randomWords(allowed, ub=2, lb=1)]
+                  for word in [randomWord()] ])
+ 
+     for n in range(7):
+         w = randomWord(minimum=3)
+         problem("ensure suffix `%s`"%w,
+                 [ ((s + (w if f else ""),), s + w)
+                   for _ in range(NUMBEROFEXAMPLES)
+                   for s in [randomWords(" ")]
+                   for f in [random.choice([True,False])] ])
+ 
+ 
+     for p in problems:
+         guessConstantStrings(p)
+ 
+     return problems
+ 
+ 
+ def loadPBETasks(directory="PBE_Strings_Track"):
+     """
+     Processes sygus benchmarks into task objects
+     For these benchmarks, all of the constant strings are given to us.
+     In a sense this is cheating
+     Returns (tasksWithoutCheating, tasksWithCheating).
+     NB: Results in paper are done without "cheating"
+     """
+     import os
+     from sexpdata import loads, Symbol
+ 
+     def findStrings(s):
+         if isinstance(s, list):
+             return [y
+                     for x in s
+                     for y in findStrings(x)]
+         if isinstance(s, str):
+             return [s]
+         return []
+ 
+     def explode(s):
+         return [c for c in s]
+ 
+     tasks = []
+     cheatingTasks = []
+     for f in os.listdir(directory):
+         if not f.endswith('.sl'):
+             continue
+         with open(directory + "/" + f, "r") as handle:
+             message = "(%s)" % (handle.read())
+ 
+         expression = loads(message)
+ 
+         constants = []
+         name = f
+         examples = []
+         declarative = False
+         for e in expression:
+             if len(e) == 0:
+                 continue
+             if e[0] == Symbol('constraint'):
+                 e = e[1]
+                 assert e[0] == Symbol('=')
+                 inputs = e[1]
+                 assert inputs[0] == Symbol('f')
+                 inputs = inputs[1:]
+                 output = e[2]
+                 examples.append((inputs, output))
+             elif e[0] == Symbol('synth-fun'):
+                 if e[1] == Symbol('f'):
+                     constants += findStrings(e)
+                 else:
+                     declarative = True
+                     break
+         if declarative: continue
+ 
+         examples = list({(tuple(xs), y) for xs, y in examples})
+ 
+         task = Task(name, arrow(*[tstr] * (len(examples[0][0]) + 1)),
+                     [(tuple(map(explode, xs)), explode(y))
+                      for xs, y in examples])
+         cheat = task
+ 
+         tasks.append(task)
+         cheatingTasks.append(cheat)
+ 
+     for p in tasks:
+         guessConstantStrings(p)
+     return tasks, cheatingTasks
+ 
+ 
+ def guessConstantStrings(task):
+     if task.request.returns() == tlist(tcharacter):
+         examples = task.examples
+         guesses = {}
+         N = 10
+         T = 2
+         for n in range(min(N, len(examples))):
+             for m in range(n + 1, min(N, len(examples))):
+                 y1 = examples[n][1]
+                 y2 = examples[m][1]
+                 l = ''.join(lcs(y1, y2))
+                 if len(l) > 2:
+                     guesses[l] = guesses.get(l, 0) + 1
+ 
+         task.stringConstants = [g for g, f in guesses.items()
+                                 if f >= T]
+     else:
+         task.stringConstants = []
+ 
+ 
+     task.BIC = 1.
+     task.maxParameters = 1
+ 
+     task.specialTask = ("stringConstant",
+                         {"maxParameters": task.maxParameters,
+                          "stringConstants": task.stringConstants})
+ 
+ 
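`guessConstantStrings` proposes string constants by taking pairwise longest common substrings over up to N=10 outputs and keeping any substring longer than 2 characters that recurs at least T=2 times. A standalone replication of that inner loop on hypothetical outputs:

# Illustrative only: three made-up task outputs sharing the prefix "Dr. ".
outputs = ["Dr. Kim", "Dr. Lee", "Dr. Poe"]
guesses = {}
for n in range(len(outputs)):
    for m in range(n + 1, len(outputs)):
        l = ''.join(lcs(outputs[n], outputs[m]))
        if len(l) > 2:
            guesses[l] = guesses.get(l, 0) + 1
print([g for g, f in guesses.items() if f >= 2])  # -> ['Dr. ']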
+ if __name__ == "__main__":
+     challenge, _ = loadPBETasks("data/sygus")
+ 
+     tasks = makeTasks()
+     print(len(tasks), "synthetic tasks")
+     tasks = []
+     for t in tasks + challenge:
+         print(t.name)
+         for xs, y in t.examples:
+             xs = ['"' + "".join(x) + '"' for x in xs]
+             y = "".join(y)
+             print('f(%s) = "%s"' % (", ".join(xs), y))
+         print("\t{%s}" % (t.stringConstants))
+         print()
+     assert False
+     # def maximumLength(x):
+     #     if isinstance(x,list):
+     #         return max([len(x)] + map(maximumLength,x))
+     #     return 1
+ 
+     # print max(maximumLength(z) for t in tasks
+     #           for (x,),y in t.examples
+     #           for z in [x,y] )
+ 
+     if len(sys.argv) > 1 and "json" in sys.argv[1]:
+         import json
+         tasks = makeTasks()
+         obj = [t.as_json_dict() for t in tasks]
+         json.dump(obj, sys.stdout)
+     else:
+         as_tex = len(sys.argv) > 1 and "tex" in sys.argv[1]
+         for t in tasks:
+             print(t.name)
+             print(t.request)
+             if as_tex:
+                 print("""\\begin{tabular}{ll}
+ \\toprule Input&Output\\\\\\midrule
+ %s
+ \\\\\\bottomrule
+ \\end{tabular}""" % (" \\\\\n ".join(x[0] + " & " + y for x, y in t.examples)))
+             else:
+                 for x, y in t.examples:
+                     print(x[0], '\t', y)
+             print()
+         print(len(tasks), "tasks")
dreamcoder/domains/text/textPrimitives.py ADDED
@@ -0,0 +1,87 @@
+ from dreamcoder.program import *
+ from dreamcoder.domains.text.makeTextTasks import delimiters
+ 
+ def _isUpper(x): return x.isupper()
+ 
+ def _increment(x): return x + 1
+ 
+ 
+ def _decrement(x): return x - 1
+ 
+ 
+ def _lower(x): return x.lower()
+ 
+ 
+ def _upper(x): return x.upper()
+ 
+ 
+ def _capitalize(x): return x.capitalize()
+ 
+ 
+ def _append(x): return lambda y: x + y
+ 
+ 
+ def _slice(x): return lambda y: lambda s: s[x:y]
+ 
+ 
+ def _index(n): return lambda x: x[n]
+ 
+ 
+ def _map(f): return lambda x: list(map(f, x))
+ 
+ 
+ def _find(pattern): return lambda s: s.index(pattern)
+ 
+ 
+ def _replace(original): return lambda replacement: lambda target: target.replace(
+     original, replacement)
+ 
+ 
+ def _split(delimiter): return lambda s: s.split(delimiter)
+ 
+ 
+ def _join(delimiter): return lambda ss: delimiter.join(ss)
+ 
+ 
+ def _identity(x): return x
+ #def _reverse(x): return x[::-1]
+ 
+ 
+ def _strip(x): return x.strip()
+ 
+ 
+ def _eq(x): return lambda y: x == y
+ 
+ 
+ specialCharacters = {' ': 'SPACE',
+                      ')': 'RPAREN',
+                      '(': 'LPAREN'}
+ 
+ primitives = [
+     Primitive("char-eq?", arrow(tcharacter, tcharacter, tboolean), _eq),
+     Primitive("STRING", tstr, None)
+ ] + [Primitive("'%s'" % d, tcharacter, d) for d in delimiters if d not in specialCharacters] + \
+     [Primitive(name, tcharacter, value) for value, name in specialCharacters.items()]
+ 
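The helpers above are deliberately curried (one argument per call) so that each value matches its `arrow(...)` type, since synthesized programs apply primitives one argument per application. A small illustration using two of them; the variable names are only for this sketch:

# Curried application mirrors how a synthesized program applies primitives.
is_a = _eq('a')
assert is_a('a') and not is_a('b')
assert _replace('-')(' ')("a-b") == "a b"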
+ 
+ def _cons(x): return lambda y: [x] + y
+ 
+ 
+ def _car(x): return x[0]
+ 
+ 
+ def _cdr(x): return x[1:]
+ 
+ 
+ targetTextPrimitives = [
+     Primitive("take-word", arrow(tcharacter, tstr, tstr), None),
+     Primitive("drop-word", arrow(tcharacter, tstr, tstr), None),
+     Primitive("append", arrow(tlist(t0), tlist(t0), tlist(t0)), None),
+     Primitive("abbreviate", arrow(tstr, tstr), None),
+     Primitive("last-word", arrow(tcharacter, tstr, tstr), None),
+     Primitive("replace-character", arrow(tcharacter, tcharacter, tstr, tstr), None),
+ ] + primitives + [
+     Primitive("empty", tlist(t0), []),
+     Primitive("cons", arrow(t0, tlist(t0), tlist(t0)), _cons),
+     Primitive("car", arrow(tlist(t0), t0), _car),
+     Primitive("cdr", arrow(tlist(t0), tlist(t0)), _cdr)]
dreamcoder/domains/tower/__init__.py ADDED
File without changes
dreamcoder/domains/tower/main.py ADDED
@@ -0,0 +1,359 @@
+ from dreamcoder.dreamcoder import *
+ 
+ from dreamcoder.domains.tower.towerPrimitives import primitives, new_primitives, animateTower
+ from dreamcoder.domains.tower.makeTowerTasks import *
+ from dreamcoder.domains.tower.tower_common import renderPlan, towerLength, centerTower
+ from dreamcoder.utilities import *
+ 
+ import os
+ import datetime
+ 
+ try: #pypy will fail
+     from dreamcoder.recognition import variable
+     import torch.nn as nn
+     import torch.nn.functional as F
+ 
+     class Flatten(nn.Module):
+         def __init__(self):
+             super(Flatten, self).__init__()
+ 
+         def forward(self, x):
+             return x.view(x.size(0), -1)
+ 
+ 
+     class TowerCNN(nn.Module):
+         special = 'tower'
+ 
+         def __init__(self, tasks, testingTasks=[], cuda=False, H=64):
+             super(TowerCNN, self).__init__()
+             self.CUDA = cuda
+             self.recomputeTasks = True
+ 
+             self.outputDimensionality = H
+             def conv_block(in_channels, out_channels):
+                 return nn.Sequential(
+                     nn.Conv2d(in_channels, out_channels, 3, padding=1),
+                     # nn.BatchNorm2d(out_channels),
+                     nn.ReLU(),
+                     nn.MaxPool2d(2)
+                 )
+ 
+             self.inputImageDimension = 256
+             self.resizedDimension = 64
+             assert self.inputImageDimension % self.resizedDimension == 0
+ 
+             # channels for hidden
+             hid_dim = 64
+             z_dim = 64
+ 
+             self.encoder = nn.Sequential(
+                 conv_block(6, hid_dim),
+                 conv_block(hid_dim, hid_dim),
+                 conv_block(hid_dim, hid_dim),
+                 conv_block(hid_dim, z_dim),
+                 Flatten()
+             )
+ 
+             self.outputDimensionality = 1024
+ 
+             if cuda:
+                 self.CUDA=True
+                 self.cuda()  # I think this should work?
+ 
+         def forward(self, v, v2=None):
+             """v: tower to build. v2: image of tower we have built so far"""
+             # insert batch if it is not already there
+             if len(v.shape) == 3:
+                 v = np.expand_dims(v, 0)
+                 inserted_batch = True
+                 if v2 is not None:
+                     assert len(v2.shape) == 3
+                     v2 = np.expand_dims(v2, 0)
+             elif len(v.shape) == 4:
+                 inserted_batch = False
+                 pass
+             else:
+                 assert False, "v has the shape %s"%(str(v.shape))
+ 
+             if v2 is None: v2 = np.zeros(v.shape)
+ 
+             v = np.concatenate((v,v2), axis=3)
+             v = np.transpose(v,(0,3,1,2))
+             assert v.shape == (v.shape[0], 6,self.inputImageDimension,self.inputImageDimension)
+             v = variable(v, cuda=self.CUDA).float()
+             window = int(self.inputImageDimension/self.resizedDimension)
+             v = F.avg_pool2d(v, (window,window))
+             #showArrayAsImage(np.transpose(v.data.numpy()[0,:3,:,:],[1,2,0]))
+             v = self.encoder(v)
+             if inserted_batch:
+                 return v.view(-1)
+             else:
+                 return v
+ 
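The shape bookkeeping in `forward` above, traced on a hypothetical batch of two towers: the goal rendering and the partially-built rendering are stacked channel-wise to give 6 input channels, then average-pooled from 256x256 down to 64x64 (window 256//64 == 4) before the convolutional encoder. The zero arrays here are stand-ins for real renderings:

import numpy as np

goal = np.zeros((2, 256, 256, 3))   # v: towers to build
built = np.zeros((2, 256, 256, 3))  # v2: towers built so far
x = np.concatenate((goal, built), axis=3)
x = np.transpose(x, (0, 3, 1, 2))
assert x.shape == (2, 6, 256, 256)  # what the assert in forward checks
# F.avg_pool2d with window 4 would then yield shape (2, 6, 64, 64).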
+         def featuresOfTask(self, t, t2=None):  # Take a task and returns [features]
+             return self(t.getImage(),
+                         None if t2 is None else t2.getImage(drawHand=True))
+ 
+         def featuresOfTasks(self, ts, t2=None):  # Take a task and returns [features]
+             """Takes the goal first; optionally also takes the current state second"""
+             if t2 is None:
+                 pass
+             elif isinstance(t2, Task):
+                 assert False
+                 #t2 = np.array([t2.getImage(drawHand=True)]*len(ts))
+             elif isinstance(t2, list):
+                 t2 = np.array([t.getImage(drawHand=True) if t else np.zeros((self.inputImageDimension,
+                                                                              self.inputImageDimension,
+                                                                              3))
+                                for t in t2])
+             else:
+                 assert False
+ 
+             return self(np.array([t.getImage() for t in ts]),
+                         t2)
+ 
+         def taskOfProgram(self, p, t,
+                           lenient=False):
+             try:
+                 pl = executeTower(p,0.05)
+                 if pl is None or (not lenient and len(pl) == 0): return None
+                 if len(pl) > 100 or towerLength(pl) > 360: return None
+ 
+                 t = SupervisedTower("tower dream", p)
+                 return t
+             except Exception as e:
+                 return None
+ except: pass
+ 
+ 
+ 
+ def tower_options(parser):
+     parser.add_argument("--tasks",
+                         choices=["old","new"],
+                         default="old")
+     parser.add_argument("--visualize",
+                         default=None, type=str)
+     parser.add_argument("--solutions",
+                         default=None, type=str)
+     parser.add_argument("--split",
+                         default=1., type=float)
+     parser.add_argument("--dream",
+                         default=None, type=str)
+     parser.add_argument("--primitives",
+                         default="old", type=str,
+                         choices=["new", "old"])
+ 
+ 
+ def dreamOfTowers(grammar, prefix, N=250, make_montage=True):
+     request = arrow(ttower,ttower)
+     randomTowers = [tuple(centerTower(t))
+                     for _ in range(N)
+                     for program in [grammar.sample(request,
+                                                    maximumDepth=12,
+                                                    maxAttempts=100)]
+                     if program is not None
+                     for t in [executeTower(program, timeout=0.5) or []]
+                     if len(t) >= 1 and len(t) < 100 and towerLength(t) <= 360.]
+     matrix = [renderPlan(p,Lego=True,pretty=True)
+               for p in randomTowers]
+ 
+     # Only visualize if it has something to visualize.
+     if len(matrix) > 0:
+         import scipy.misc
+         if make_montage:
+             matrix = montage(matrix)
+             scipy.misc.imsave('%s.png'%prefix, matrix)
+         else:
+             for n,i in enumerate(matrix):
+                 scipy.misc.imsave(f'{prefix}/{n}.png', i)
+     else:
+         eprint("Tried to visualize dreams, but none to visualize.")
+ 
+ 
+ def visualizePrimitives(primitives, fn=None):
+     from itertools import product
+     #from pylab import imshow,show
+ 
+     from dreamcoder.domains.tower.towerPrimitives import _left,_right,_loop,_embed,_empty_tower,TowerState
+     _13 = Program.parse("1x3").value
+     _31 = Program.parse("3x1").value
+ 
+     r = lambda n,k: _right(2*n)(k)
+     l = lambda n,k: _left(2*n)(k)
+     _e = _embed
+     _lp = lambda n,b,k: _loop(n)(b)(k)
+     _arch = lambda k: l(1,_13(r(2,_13(l(1,_31(k))))))
+     _tallArch = lambda h,z,k: _lp(h, lambda _: _13(r(2,_13(l(2,z)))),
+                                   r(1,_31(k)))
+ 
+     matrix = []
+     for p in primitives:
+         if not p.isInvented: continue
+         eprint(p,":",p.tp)
+         t = p.tp
+         if t.returns() != ttower: continue
+ 
+         def argumentChoices(t):
+             if t == ttower:
+                 return [_empty_tower]
+             elif t == tint:
+                 return list(range(5))
+             elif t == arrow(ttower,ttower):
+                 return [_arch,_13,_31]
+             else:
+                 return []
+ 
+         ts = []
+         for arguments in product(*[argumentChoices(t) for t in t.functionArguments() ]):
+             t = p.evaluate([])
+             for a in arguments: t = t(a)
+             t = t(TowerState())[1]
+             ts.append(t)
+ 
+         if ts == []: continue
+ 
+         matrix.append([renderPlan(p,pretty=True)
+                        for p in ts])
+ 
+     # Only visualize if it has something to visualize.
+     if len(matrix) > 0:
+         matrix = montageMatrix(matrix)
+         # imshow(matrix)
+ 
+         import scipy.misc
+         scipy.misc.imsave(fn, matrix)
+         # show()
+     else:
+         eprint("Tried to visualize primitives, but none to visualize.")
+ 
+ def animateSolutions(checkpoint):
+     with open(checkpoint,"rb") as handle: result = dill.load(handle)
+     for n,f in enumerate(result.taskSolutions.values()):
+         animateTower(f"/tmp/tower_animation_{n}",f.bestPosterior.program)
+ 
+ def visualizeSolutions(solutions, export, tasks=None):
+ 
+     if tasks is None:
+         tasks = list(solutions.keys())
+         tasks.sort(key=lambda t: len(t.plan))
+ 
+     matrix = []
+     for t in tasks:
+         i = renderPlan(centerTower(t.plan),pretty=True,Lego=True)
+         if solutions[t].empty: i = i/3.
+         matrix.append(i)
+ 
+     # Only visualize if it has something to visualize.
+     if len(matrix) > 0:
+         matrix = montage(matrix)
+         import scipy.misc
+         scipy.misc.imsave(export, matrix)
+     else:
+         eprint("Tried to visualize solutions, but none to visualize.")
+ 
+ 
+ def main(arguments):
+     """
+     Takes the return value of the `commandlineArguments()` function as input and
+     trains/tests the model on a set of tower-building tasks.
+     """
+ 
+     # The below global statement is required since primitives is modified within main().
+     # TODO(lcary): use a function call to retrieve and declare primitives instead.
+     global primitives
+ 
+     import scipy.misc
+ 
+     g0 = Grammar.uniform({"new": new_primitives,
+                           "old": primitives}[arguments.pop("primitives")],
+                          continuationType=ttower)
+ 
+     checkpoint = arguments.pop("visualize")
+     if checkpoint is not None:
+         with open(checkpoint,'rb') as handle:
+             primitives = pickle.load(handle).grammars[-1].primitives
+         visualizePrimitives(primitives)
+         sys.exit(0)
+     checkpoint = arguments.pop("solutions")
+     if checkpoint is not None:
+         with open(checkpoint,'rb') as handle:
+             solutions = pickle.load(handle).taskSolutions
+         visualizeSolutions(solutions,
+                            checkpoint + ".solutions.png")
+         animateSolutions(checkpoint)
+         sys.exit(0)
+     checkpoint = arguments.pop("dream")
+     if checkpoint is not None:
+         with open(checkpoint,'rb') as handle:
+             g = pickle.load(handle).grammars[-1]
+         os.system("mkdir -p data/tower_dreams")
+         dreamOfTowers(g, "data/tower_dreams", make_montage=False)
+         sys.exit(0)
+ 
+ 
+     tasks = arguments.pop("tasks")
+     if tasks == "new":
+         tasks = makeSupervisedTasks()
+     elif tasks == "old":
+         tasks = makeOldSupervisedTasks()
+     else: assert False
+ 
+     test, train = testTrainSplit(tasks, arguments.pop("split"))
+     eprint("Split %d/%d test/train" % (len(test), len(train)))
+ 
+     # Make a montage for the paper
+     shuffledTrain = list(train)
+     shuffledTest = list(test)
+     random.shuffle(shuffledTrain)
+     shuffledTrain = shuffledTrain + [None]*(60 - len(shuffledTrain))
+     random.shuffle(shuffledTest)
+     shuffledTest = shuffledTest + [None]*(60 - len(shuffledTest))
+     try:
+         SupervisedTower.exportMany("/tmp/every_tower.png",shuffledTrain + shuffledTest, shuffle=False, columns=10)
+         for j,task in enumerate(tasks):
+             task.exportImage(f"/tmp/tower_task_{j}.png")
+         for k,v in dSLDemo().items():
+             scipy.misc.imsave(f"/tmp/tower_dsl_{k}.png", v)
+             os.system(f"convert /tmp/tower_dsl_{k}.png -channel RGB -negate /tmp/tower_dsl_{k}.png")
+     except:
+         eprint("WARNING: can't export images. scipy needs to be an older version")
+ 
+ 
+     timestamp = datetime.datetime.now().isoformat()
+     outputDirectory = "experimentOutputs/towers/%s"%timestamp
+     os.system("mkdir -p %s"%outputDirectory)
+ 
+     os.system("mkdir -p data/tower_dreams_initial")
+     try:
+         dreamOfTowers(g0, "data/tower_dreams_initial", make_montage=False)
+         dreamOfTowers(g0, "%s/random_0"%outputDirectory)
+     except:
+         eprint("WARNING: can't export images. scipy needs to be an older version")
+ 
+     evaluationTimeout = 0.005
+     generator = ecIterator(g0, train,
+                            testingTasks=test,
+                            outputPrefix="%s/tower"%outputDirectory,
+                            evaluationTimeout=evaluationTimeout,
+                            **arguments)
+ 
+ 
+ 
+ for result in generator:
343
+ continue # note: this skips the per-iteration visualization below; remove it to export solutions and dreams each iteration
344
+ iteration = len(result.learningCurve)
345
+ newTowers = [tuple(centerTower(executeTower(frontier.sample().program)))
346
+ for frontier in result.taskSolutions.values() if not frontier.empty]
347
+ try:
348
+ fn = '%s/solutions_%d.png'%(outputDirectory,iteration)
349
+ visualizeSolutions(result.taskSolutions, fn,
350
+ train)
351
+ eprint("Exported solutions to %s\n"%fn)
352
+ dreamOfTowers(result.grammars[-1],
353
+ '%s/random_%d'%(outputDirectory,iteration))
354
+ except ImportError:
355
+ eprint("Could not import required libraries for exporting towers.")
356
+ primitiveFilename = '%s/primitives_%d.png'%(outputDirectory, iteration)
357
+ visualizePrimitives(result.grammars[-1].primitives,
358
+ primitiveFilename)
359
+ eprint("Exported primitives to",primitiveFilename)
dreamcoder/domains/tower/makeTowerTasks.py ADDED
@@ -0,0 +1,556 @@
1
+ from dreamcoder.domains.tower.towerPrimitives import ttower, executeTower, _empty_tower, TowerState
2
+ from dreamcoder.domains.tower.tower_common import renderPlan
3
+ from dreamcoder.task import *
4
+
5
+
6
+ class SupervisedTower(Task):
7
+ def __init__(self, name, program, mustTrain=False):
8
+ if isinstance(program,str):
9
+ try:
10
+ program = parseTower(program)
11
+ except:
12
+ eprint("Parse failure:")
13
+ eprint(program)
14
+ assert False
15
+ self.original = program
16
+ plan = executeTower(program)
17
+ elif isinstance(program,Program):
18
+ self.original = program
19
+ plan = executeTower(program)
20
+ else:
21
+ plan = program
22
+ self.original = program
23
+ state, self.plan = program.evaluate([])(_empty_tower)(TowerState())
24
+ self.hand = state.hand
25
+ super(SupervisedTower, self).__init__(name, arrow(ttower,ttower), [],
26
+ features=[])
27
+ self.specialTask = ("supervisedTower",
28
+ {"plan": self.plan})
29
+ self.image = None
30
+ self.handImage = None
31
+ self.mustTrain = mustTrain
32
+
33
+ def getImage(self, drawHand=False, pretty=False):
34
+ if not drawHand:
35
+ if not pretty:
36
+ if self.image is not None: return self.image
37
+ self.image = renderPlan(self.plan, pretty=pretty)
38
+ return self.image
39
+ else:
40
+ return renderPlan(self.plan, pretty=True)
41
+ else:
42
+ if self.handImage is not None: return self.handImage
43
+ self.handImage = renderPlan(self.plan,
44
+ drawHand=self.hand,
45
+ pretty=pretty)
46
+ return self.handImage
47
+
48
+
49
+
50
+ # do not pickle the image
51
+ def __getstate__(self):
52
+ return self.specialTask, self.plan, self.request, self.cache, self.name, self.examples
53
+ def __setstate__(self, state):
54
+ self.specialTask, self.plan, self.request, self.cache, self.name, self.examples = state
55
+ self.image = None
+ self.handImage = None
56
+
57
+
58
+ def animate(self):
59
+ from pylab import imshow,show
60
+ a = renderPlan(self.plan)
61
+ imshow(a)
62
+ show()
63
+
64
+ @staticmethod
65
+ def showMany(ts):
66
+ from pylab import imshow,show
67
+ a = montage([renderPlan(t.plan, pretty=True, Lego=True, resolution=256,
68
+ drawHand=False)
69
+ for t in ts])
70
+ imshow(a)
71
+ show()
72
+
73
+ @staticmethod
74
+ def exportMany(f, ts, shuffle=True, columns=None):
75
+ import numpy as np
76
+
77
+ ts = list(ts)
78
+ if shuffle:
79
+ assert all( t is not None for t in ts )
80
+ random.shuffle(ts)
81
+ a = montage([renderPlan(t.plan, pretty=True, Lego=True, resolution=256) if t is not None \
82
+ else np.zeros((256,256,3))
83
+ for t in ts],
84
+ columns=columns)
85
+ import scipy.misc
86
+ scipy.misc.imsave(f, a)
87
+
88
+
89
+ def exportImage(self, f, pretty=True, Lego=True, drawHand=False):
90
+ a = renderPlan(self.plan,
91
+ pretty=pretty, Lego=Lego,
92
+ drawHand=self.hand if drawHand else None)
93
+ import scipy.misc
94
+ scipy.misc.imsave(f, a)
95
+
96
+ def logLikelihood(self, e, timeout=None):
97
+ from dreamcoder.domains.tower.tower_common import centerTower
98
+ yh = executeTower(e, timeout)
99
+ if yh is not None and centerTower(yh) == centerTower(self.plan): return 0.
100
+ return NEGATIVEINFINITY
101
+
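+ # Note: scoring is all-or-nothing -- a candidate program gets log-likelihood 0
+ # only when its executed plan, after centering, is an exact match for the target.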
102
+
103
+
104
+ def parseTower(s):
105
+ _13 = Program.parse("1x3")
106
+ _31 = Program.parse("3x1")
107
+ _r = Program.parse("right")
108
+ _l = Program.parse("left")
109
+ _addition = Program.parse("+")
110
+ _subtraction = Program.parse("-")
111
+ _lp = Program.parse("tower_loopM")
112
+ _e = Program.parse("tower_embed")
113
+
114
+ from sexpdata import loads, Symbol
115
+ s = loads(s)
116
+ def command(k, environment, continuation):
117
+ if k == Symbol("1x3") or k == Symbol("v"): return Application(_13, continuation)
118
+ if k == Symbol("3x1") or k == Symbol("h"): return Application(_31, continuation)
119
+ assert isinstance(k,list)
120
+ if k[0] == Symbol("r"): return Application(Application(_r, expression(k[1],environment)),continuation)
121
+ if k[0] == Symbol("l"): return Application(Application(_l, expression(k[1],environment)),continuation)
122
+ if k[0] == Symbol("for"):
123
+ v = k[1]
124
+ b = expression(k[2], environment)
125
+ newEnvironment = [None, v] + environment
126
+ body = block(k[3:], newEnvironment, Index(0))
127
+ return Application(Application(Application(_lp,b),
128
+ Abstraction(Abstraction(body))),
129
+ continuation)
130
+ if k[0] == Symbol("embed"):
131
+ body = block(k[1:], [None] + environment, Index(0))
132
+ return Application(Application(_e,Abstraction(body)),continuation)
133
+
134
+ assert False
135
+ def expression(e, environment):
136
+ for n, v in enumerate(environment):
137
+ if e == v: return Index(n)
138
+
139
+ if isinstance(e,int): return Program.parse(str(e))
140
+
141
+ assert isinstance(e,list)
142
+ if e[0] == Symbol('+'): return Application(Application(_addition, expression(e[1], environment)),
143
+ expression(e[2], environment))
144
+ if e[0] == Symbol('-'): return Application(Application(_subtraction, expression(e[1], environment)),
145
+ expression(e[2], environment))
146
+ assert False
147
+
148
+ def block(b, environment, continuation):
149
+ if len(b) == 0: return continuation
150
+ return command(b[0], environment, block(b[1:], environment, continuation))
151
+
152
+ try: return Abstraction(command(s, [], Index(0)))
153
+ except: return Abstraction(block(s, [], Index(0)))
154
+
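+ # A sketch of the concrete syntax parseTower accepts (the example tower is
+ # illustrative): "v"/"1x3" drops a vertical block and "h"/"3x1" a horizontal
+ # one; (r n)/(l n) move the hand; (for i n ...) loops; (embed ...) runs its
+ # body and then restores the hand. For instance,
+ # parseTower("((for i 2 v) (r 4) (for i 2 v) (l 2) h)") builds a two-legged arch.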
155
+
156
+ def makeSupervisedTasks():
157
+ arches = [SupervisedTower("arch leg %d"%n,
158
+ "((for i %d v) (r 4) (for i %d v) (l 2) h)"%(n,n))
159
+ for n in range(1,9)
160
+ ]
161
+ archesStacks = [SupervisedTower("arch stack %d"%n,
162
+ """
163
+ (for i %d
164
+ v (r 4) v (l 2) h (l 2))
165
+ """%n)
166
+ for n in range(3,7) ]
167
+ Bridges = [SupervisedTower("bridge (%d) of arch %d"%(n,l),
168
+ """
169
+ (for j %d
170
+ (for i %d
171
+ v (r 4) v (l 4)) (r 2) h
172
+ (r 4))
173
+ """%(n,l))
174
+ for n in range(2,8)
175
+ for l in range(1,6)]
176
+ offsetArches = [SupervisedTower("bridge (%d) of arch, spaced %d"%(n,l),
177
+ """
178
+ (for j %d
179
+ (embed v (r 4) v (l 2) h )
180
+ (r %d))
181
+ """%(n,l),
182
+ mustTrain=n == 3)
183
+ for n,l in [(3,7),(4,8)]]
184
+ Josh = [SupervisedTower("Josh (%d)"%n,
185
+ """(for i %d
186
+ h (l 2) v (r 2) v (r 2) v (l 2) h (r 6))"""%n)
187
+ for n in range(1,7) ]
188
+
189
+ staircase1 = [SupervisedTower("R staircase %d"%n,
190
+ """
191
+ (for i %d (for j i
192
+ (embed v (r 4) v (l 2) h)) (r 6))
193
+ """%(n))
194
+ for n in range(3,8) ]
195
+ staircase2 = [SupervisedTower("L staircase %d"%n,
196
+ """
197
+ (for i %d (for j i
198
+ (embed v (r 4) v (l 2) h)) (l 6))
199
+ """%(n))
200
+ for n in range(3,8) ]
201
+ simpleLoops = [SupervisedTower("%s row %d, spacing %d"%(o,n,s),
202
+ """(for j %d %s (r %s))"""%(n,o,s),
203
+ mustTrain=True)
204
+ for o,n,s in [('h',4,7), ('v',5,3)] ]
205
+
206
+ pyramids = []
207
+ pyramids += [SupervisedTower("arch pyramid %d"%n,
208
+ """((for i %d (for j i (embed v (r 4) v (l 2) h)) (r 6))
209
+ (for i %d (for j (- %d i) (embed v (r 4) v (l 2) h)) (r 6)))"""%(n,n,n))
210
+ for n in range(2,6) ]
211
+ pyramids += [SupervisedTower("H pyramid %d"%n,
212
+ """((for i %d (for j i h) (r 6))
213
+ (for i %d (for j (- %d i) h) (r 6)))"""%(n,n,n))
214
+ for n in range(4,6) ]
215
+ # pyramids += [SupervisedTower("V pyramid %d"%n,
216
+ # """
217
+ # ((for i %d (for j i v) (r 2))
218
+ # (for i %d (for j (- %d i) v) (r 2)))
219
+ # """%(n,n,n))
220
+ # for n in range(4,8) ]
221
+ # pyramids += [SupervisedTower("V3 pyramid %d"%n,
222
+ # """
223
+ # ((for i %d (for j i v) (r 6))
224
+ # (for i %d (for j (- %d i) v) (r 6)))
225
+ # """%(n,n,n))
226
+ # for n in range(4,8) ]
227
+ pyramids += [SupervisedTower("H 1/2 pyramid %d"%n,
228
+ """
229
+ (for i %d
230
+ (r 6)
231
+ (embed
232
+ (for j i h (l 3))))
233
+ """%n)
234
+ for n in range(4,8) ]
235
+ pyramids += [SupervisedTower("arch 1/2 pyramid %d"%n,
236
+ """
237
+ (for i %d
238
+ (r 6)
239
+ (embed
240
+ (for j i (embed v (r 4) v (l 2) h) (l 3))))
241
+ """%n)
242
+ for n in range(2,8) ]
243
+ if False:
244
+ pyramids += [SupervisedTower("V 1/2 pyramid %d"%n,
245
+ """
246
+ (for i %d
247
+ (r 2)
248
+ (embed
249
+ (for j i v (l 1))))"""%(n))
250
+ for n in range(4,8) ]
251
+ bricks = [SupervisedTower("brickwall, %dx%d"%(w,h),
252
+ """(for j %d
253
+ (embed (for i %d h (r 6)))
254
+ (embed (r 3) (for i %d h (r 6))))"""%(h,w,w))
255
+ for w in range(3,7)
256
+ for h in range(1,6) ]
257
+ aqueducts = [SupervisedTower("aqueduct: %dx%d"%(w,h),
258
+ """(for j %d
259
+ %s (r 4) %s (l 2) h (l 2) v (r 4) v (l 2) h (r 4))"""%
260
+ (w, "v "*h, "v "*h))
261
+ for w in range(4,8)
262
+ for h in range(3,6)
263
+ ]
264
+
265
+ compositions = [SupervisedTower("%dx%d-bridge on top of %dx%d bricks"%(b1,b2,w1,w2),
266
+ """
267
+ ((for j %d
268
+ (embed (for i %d h (r 6)))
269
+ (embed (r 3) (for i %d h (r 6))))
270
+ (r 1)
271
+ (for j %d
272
+ (for i %d
273
+ v (r 4) v (l 4)) (r 2) h
274
+ (r 4)))
275
+ """%(w1,w2,w2,b1,b2))
276
+ for b1,b2,w1,w2 in [(5,2,4,5)]
277
+ ] + [
278
+ SupervisedTower("%d pyramid on top of %dx%d bricks"%(p,w1,w2),
279
+ """
280
+ ((for j %d
281
+ (embed (for i %d h (r 6)))
282
+ (embed (r 3) (for i %d h (r 6))))
283
+ (r 1)
284
+ (for i %d (for j i (embed v (r 4) v (l 2) h)) (r 6))
285
+ (for i %d (for j (- %d i) (embed v (r 4) v (l 2) h)) (r 6)))
286
+ """%(w1,w2,w2,p,p,p))
287
+ for w1,w2,p in [(2,5,2)]
288
+ ] + \
289
+ [
290
+ SupervisedTower("%d tower on top of %dx%d bricks"%(t,w1,w2),
291
+ """
292
+ ((for j %d
293
+ (embed (for i %d h (r 6)))
294
+ (embed (r 3) (for i %d h (r 6))))
295
+ (r 6)
296
+ %s (r 4) %s (l 2) h)
297
+ """%(w1,w2,w2,
298
+ "v "*t, "v "*t))
299
+ for t,w1,w2 in [(4,1,3)] ]
300
+
301
+
302
+
303
+ everything = arches + simpleLoops + Bridges + archesStacks + aqueducts + offsetArches + pyramids + bricks + staircase2 + staircase1 + compositions
304
+ if False:
305
+ for t in everything:
306
+ delattr(t,'original')
307
+ return everything
308
+
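+ # makeOldSupervisedTasks below is kept for main()'s tasks == "old" option; it
+ # differs from makeSupervisedTasks mainly in its offsetArches and simpleLoops sets.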
309
+ def makeOldSupervisedTasks():
310
+ arches = [SupervisedTower("arch leg %d"%n,
311
+ "((for i %d v) (r 4) (for i %d v) (l 2) h)"%(n,n))
312
+ for n in range(1,9)
313
+ ]
314
+ archesStacks = [SupervisedTower("arch stack %d"%n,
315
+ """
316
+ (for i %d
317
+ v (r 4) v (l 2) h (l 2))
318
+ """%n)
319
+ for n in range(3,7) ]
320
+ Bridges = [SupervisedTower("bridge (%d) of arch %d"%(n,l),
321
+ """
322
+ (for j %d
323
+ (for i %d
324
+ v (r 4) v (l 4)) (r 2) h
325
+ (r 4))
326
+ """%(n,l))
327
+ for n in range(2,8)
328
+ for l in range(1,6)]
329
+ offsetArches = [SupervisedTower("bridge (%d) of arch, spaced %d"%(n,l),
330
+ """
331
+ (for j %d
332
+ v (r 4) v (l 2) h
333
+ (r %d))
334
+ """%(n,l))
335
+ for n,l in [(3,7),(4,6)]]
336
+ Josh = [SupervisedTower("Josh (%d)"%n,
337
+ """(for i %d
338
+ h (l 2) v (r 2) v (r 2) v (l 2) h (r 6))"""%n)
339
+ for n in range(1,7) ]
340
+
341
+ staircase1 = [SupervisedTower("R staircase %d"%n,
342
+ """
343
+ (for i %d (for j i
344
+ (embed v (r 4) v (l 2) h)) (r 6))
345
+ """%(n))
346
+ for n in range(3,8) ]
347
+ staircase2 = [SupervisedTower("L staircase %d"%n,
348
+ """
349
+ (for i %d (for j i
350
+ (embed v (r 4) v (l 2) h)) (l 6))
351
+ """%(n))
352
+ for n in range(3,8) ]
353
+ simpleLoops = [SupervisedTower("horizontal row %d, spacing %d"%(n,s),
354
+ """(for j %d h (r %s))"""%(n,s))
355
+ for n,s in [(4,6),(5,7)] ]+\
356
+ [SupervisedTower("horizontal stack %d"%n,
357
+ """(for j %d h)"""%n)
358
+ for n in range(5,8) ]+\
359
+ [SupervisedTower("vertical stack %d"%n,
360
+ """(for j %d v)"""%n)
361
+ for n in [5,7] ]
362
+ pyramids = []
363
+ pyramids += [SupervisedTower("arch pyramid %d"%n,
364
+ """((for i %d (for j i (embed v (r 4) v (l 2) h)) (r 6))
365
+ (for i %d (for j (- %d i) (embed v (r 4) v (l 2) h)) (r 6)))"""%(n,n,n))
366
+ for n in range(2,6) ]
367
+ pyramids += [SupervisedTower("H pyramid %d"%n,
368
+ """((for i %d (for j i h) (r 6))
369
+ (for i %d (for j (- %d i) h) (r 6)))"""%(n,n,n))
370
+ for n in range(4,6) ]
371
+ # pyramids += [SupervisedTower("V pyramid %d"%n,
372
+ # """
373
+ # ((for i %d (for j i v) (r 2))
374
+ # (for i %d (for j (- %d i) v) (r 2)))
375
+ # """%(n,n,n))
376
+ # for n in range(4,8) ]
377
+ # pyramids += [SupervisedTower("V3 pyramid %d"%n,
378
+ # """
379
+ # ((for i %d (for j i v) (r 6))
380
+ # (for i %d (for j (- %d i) v) (r 6)))
381
+ # """%(n,n,n))
382
+ # for n in range(4,8) ]
383
+ pyramids += [SupervisedTower("H 1/2 pyramid %d"%n,
384
+ """
385
+ (for i %d
386
+ (r 6)
387
+ (embed
388
+ (for j i h (l 3))))
389
+ """%n)
390
+ for n in range(4,8) ]
391
+ pyramids += [SupervisedTower("arch 1/2 pyramid %d"%n,
392
+ """
393
+ (for i %d
394
+ (r 6)
395
+ (embed
396
+ (for j i (embed v (r 4) v (l 2) h) (l 3))))
397
+ """%n)
398
+ for n in range(2,8) ]
399
+ if False:
400
+ pyramids += [SupervisedTower("V 1/2 pyramid %d"%n,
401
+ """
402
+ (for i %d
403
+ (r 2)
404
+ (embed
405
+ (for j i v (l 1))))"""%(n))
406
+ for n in range(4,8) ]
407
+ bricks = [SupervisedTower("brickwall, %dx%d"%(w,h),
408
+ """(for j %d
409
+ (embed (for i %d h (r 6)))
410
+ (embed (r 3) (for i %d h (r 6))))"""%(h,w,w))
411
+ for w in range(3,7)
412
+ for h in range(1,6) ]
413
+ aqueducts = [SupervisedTower("aqueduct: %dx%d"%(w,h),
414
+ """(for j %d
415
+ %s (r 4) %s (l 2) h (l 2) v (r 4) v (l 2) h (r 4))"""%
416
+ (w, "v "*h, "v "*h))
417
+ for w in range(4,8)
418
+ for h in range(3,6)
419
+ ]
420
+
421
+ compositions = [SupervisedTower("%dx%d-bridge on top of %dx%d bricks"%(b1,b2,w1,w2),
422
+ """
423
+ ((for j %d
424
+ (embed (for i %d h (r 6)))
425
+ (embed (r 3) (for i %d h (r 6))))
426
+ (r 1)
427
+ (for j %d
428
+ (for i %d
429
+ v (r 4) v (l 4)) (r 2) h
430
+ (r 4)))
431
+ """%(w1,w2,w2,b1,b2))
432
+ for b1,b2,w1,w2 in [(5,2,4,5)]
433
+ ] + [
434
+ SupervisedTower("%d pyramid on top of %dx%d bricks"%(p,w1,w2),
435
+ """
436
+ ((for j %d
437
+ (embed (for i %d h (r 6)))
438
+ (embed (r 3) (for i %d h (r 6))))
439
+ (r 1)
440
+ (for i %d (for j i (embed v (r 4) v (l 2) h)) (r 6))
441
+ (for i %d (for j (- %d i) (embed v (r 4) v (l 2) h)) (r 6)))
442
+ """%(w1,w2,w2,p,p,p))
443
+ for w1,w2,p in [(2,5,2)]
444
+ ] + \
445
+ [
446
+ SupervisedTower("%d tower on top of %dx%d bricks"%(t,w1,w2),
447
+ """
448
+ ((for j %d
449
+ (embed (for i %d h (r 6)))
450
+ (embed (r 3) (for i %d h (r 6))))
451
+ (r 6)
452
+ %s (r 4) %s (l 2) h)
453
+ """%(w1,w2,w2,
454
+ "v "*t, "v "*t))
455
+ for t,w1,w2 in [(4,1,3)] ]
456
+
457
+
458
+
459
+ everything = arches + simpleLoops + Bridges + archesStacks + aqueducts + offsetArches + pyramids + bricks + staircase2 + staircase1 + compositions
460
+ if False:
461
+ for t in everything:
462
+ delattr(t,'original')
463
+ return everything
464
+
465
+ def dSLDemo():
466
+ DSL = {}
467
+ bricks = Program.parse("(lambda (lambda (tower_loopM $0 (lambda (lambda (moveHand 3 (reverseHand (tower_loopM $3 (lambda (lambda (moveHand 6 (3x1 $0)))) $0))))))))")
468
+ DSL["bricks"] = [ [bricks.runWithArguments([x,y + 4,_empty_tower,TowerState()])[1]
469
+ for y in range(6, 6 + 3*4, 3) ]
470
+ for x in [3,8] ]
471
+ dimensionality = {}
472
+ dimensionality["bricks"] = 2
473
+
474
+ bridge = Program.parse("(lambda (lambda (tower_loopM $0 (lambda (lambda (#(lambda (#(lambda (lambda (lambda (tower_loopM $0 (lambda (lambda (1x3 (moveHand 4 ($3 $0))))) (moveHand 2 (3x1 $2)))))) $0 (lambda (reverseHand $0)))) (moveHand 4 $0) $3))))))")
475
+ DSL["bridge"] = [ [bridge.runWithArguments([x,y,_empty_tower,TowerState()])[1]
476
+ for x in range(4,4 + 2*4,2) ]
477
+ for y in [4,9] ]
478
+ dimensionality["bridge"] = 2
479
+
480
+ staircase = Program.parse("(lambda (tower_loopM $0 (lambda (lambda (#(lambda (lambda (tower_loopM $1 (lambda (lambda (tower_embed (lambda (#(lambda (1x3 (moveHand 4 (1x3 (reverseHand (moveHand 2 (3x1 $0))))))) $0)) $0))) $0))) $1 (moveHand 6 $0))))))")
481
+ DSL["staircase"] = [ staircase.runWithArguments([n,_empty_tower,TowerState()])[1]
482
+ for n in range(4,5 + 3) ]
483
+
484
+ pyramid = Program.parse("(lambda (tower_loopM $0 (lambda (lambda (moveHand 6 (tower_embed (lambda (reverseHand ((lambda (lambda (tower_loopM $1 (lambda (lambda (moveHand $2 (1x3 (moveHand 2 (tower_embed (lambda (moveHand 2 (1x3 $0))) (3x1 $0)))))))))) $2 1 $0))) $0))))))")
485
+ DSL["pyramid"] = [ pyramid.runWithArguments([n,_empty_tower,TowerState()])[1]
486
+ for n in range(4,5 + 3) ]
487
+
488
+ towerArch = Program.parse("(lambda (lambda ((lambda ((lambda (lambda (lambda (tower_loopM $0 (lambda (lambda (1x3 (moveHand 4 ($3 $0))))) (moveHand 2 (3x1 $2)))))) $0 (lambda (reverseHand (1x3 $0))))) $0 $1)))")
489
+ DSL["towerArch"] = [ towerArch.runWithArguments([n,_empty_tower,TowerState()])[1]
490
+ for n in range(4,5 + 3) ]
491
+
492
+ images = {}
493
+ for k,v in DSL.items():
494
+ d = dimensionality.get(k,1)
495
+ if d == 1:
496
+ i = montageMatrix([[renderPlan(p, pretty=True, Lego=True) for p in v]])
497
+ elif d == 2:
498
+ i = montageMatrix([[renderPlan(p, pretty=True, Lego=True) for p in ps] for ps in v] )
499
+ else: assert False
500
+
501
+ images[k] = i
502
+
503
+ return images
504
+
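+ # dSLDemo returns a dict mapping concept names ("bricks", "bridge", ...) to
+ # montage images; main() and the __main__ block here write them out as
+ # /tmp/tower_dsl_<k>.png.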
505
+ if __name__ == "__main__":
506
+ from pylab import imshow,show
507
+ from dreamcoder.domains.tower.tower_common import *
508
+
509
+ ts = makeSupervisedTasks()
510
+ print(len(ts),"total tasks")
511
+ print("maximum plan length",max(len(f.plan) for f in ts ))
512
+ print("maximum tower length",max(towerLength(f.plan) for f in ts ))
513
+ print("maximum tower height",max(towerHeight(simulateWithoutPhysics(f.plan)) for f in ts ))
514
+ SupervisedTower.exportMany("/tmp/every_tower.png",ts,shuffle=False)
515
+
516
+ for j,t in enumerate(ts):
517
+ t.exportImage("/tmp/tower_%d.png"%j,
518
+ drawHand=False)
519
+
520
+ for k,v in dSLDemo().items():
521
+ import scipy.misc
522
+ scipy.misc.imsave(f"/tmp/tower_dsl_{k}.png", v)
523
+
524
+ exampleTowers = [103,104,105,93,73,
525
+ 50,67,35,43,106]
526
+ SupervisedTower.exportMany("/tmp/tower_montage.png",
527
+ [ts[n] for n in exampleTowers ],
528
+ columns=5,
529
+ shuffle=False)
530
+ assert False # intentional early exit: the montage-sampling loop below is only reached if this assert is removed
531
+
532
+
533
+ keywords = ["pyramid",
534
+ "on top of",
535
+ "arch 1/2 pyramid",
536
+ "brickwall",
537
+ "staircase",
538
+ "bridge",
539
+ "aqueduct",
540
+ "spaced",
541
+ "spaced",
542
+ "arch stack"]
543
+ for n in range(100):
544
+ examples = []
545
+ for kw in keywords:
546
+ if kw == "on top of":
547
+ examples = examples + list(filter(lambda t: kw in str(t), ts))
548
+ else:
549
+ examples.append(random.choice(list(filter(lambda t: kw in str(t), ts))))
550
+
551
+ random.shuffle(examples)
552
+ SupervisedTower.exportMany("/tmp/tower10_%d.png"%n,examples,
553
+ columns=int(len(examples)/2))
554
+
555
+
556
+
dreamcoder/domains/tower/towerPrimitives.py ADDED
@@ -0,0 +1,152 @@
1
+ from dreamcoder.program import *
2
+
3
+
4
+ class TowerState:
5
+ def __init__(self, hand=0, orientation=1, history=None):
6
+ # List of (State|Block)
7
+ self.history = history
8
+ self.hand = hand
9
+ self.orientation = orientation
10
+ def __str__(self): return f"S(h={self.hand},o={self.orientation})"
11
+ def __repr__(self): return str(self)
12
+ def left(self, n):
13
+ return TowerState(hand=self.hand - n, orientation=self.orientation,
14
+ history=self.history if self.history is None \
15
+ else self.history + [self])
16
+ def right(self, n): return TowerState(hand=self.hand + n, orientation=self.orientation,
17
+ history=self.history if self.history is None \
18
+ else self.history + [self])
19
+ def reverse(self): return TowerState(hand=self.hand, orientation=-1*self.orientation,
20
+ history=self.history if self.history is None \
21
+ else self.history + [self])
22
+ def move(self, n): return TowerState(hand=self.hand + n*self.orientation, orientation=self.orientation,
23
+ history=self.history if self.history is None \
24
+ else self.history + [self])
25
+
26
+ def recordBlock(self, b):
27
+ if self.history is None: return self
28
+ return TowerState(hand=self.hand,
29
+ orientation=self.orientation,
30
+ history=self.history + [b])
31
+
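+ # Semantics sketch: moves are relative to the current orientation, so
+ # TowerState().move(3).reverse().move(3).hand == 0. The history list stays
+ # None unless the state is constructed with history=[] (see animateTower below).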
32
+
33
+ def _empty_tower(h): return (h,[])
34
+ def _left(d):
35
+ return lambda k: lambda s: k(s.left(d))
36
+ def _right(d):
37
+ return lambda k: lambda s: k(s.right(d))
38
+ def _loop(n):
39
+ def f(start, stop, body, state):
40
+ if start >= stop: return state,[]
41
+ state, thisIteration = body(start)(state)
42
+ state, laterIterations = f(start + 1, stop, body, state)
43
+ return state, thisIteration + laterIterations
44
+ def sequence(b,k,h):
45
+ h,bodyBlocks = f(0,n,b,h)
46
+ h,laterBlocks = k(h)
47
+ return h,bodyBlocks+laterBlocks
48
+ return lambda b: lambda k: lambda h: sequence(b,k,h)
49
+ def _simpleLoop(n):
50
+ def f(start, body, k):
51
+ if start >= n: return k
52
+ return body(start)(f(start + 1, body, k))
53
+ return lambda b: lambda k: f(0,b,k)
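+ # Both loop primitives are continuation-passing: _loop threads the state
+ # through each iteration and concatenates the emitted blocks, while
+ # _simpleLoop simply composes the body n times in front of the continuation,
+ # e.g. _simpleLoop(2)(lambda i: lambda k: k)(k) reduces to k.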
54
+ def _embed(body):
55
+ def f(k):
56
+ def g(hand):
57
+ bodyHand, bodyActions = body(_empty_tower)(hand)
58
+ # Record history if we are doing that
59
+ if hand.history is not None:
60
+ hand = TowerState(hand=hand.hand,
61
+ orientation=hand.orientation,
62
+ history=bodyHand.history)
63
+ hand, laterActions = k(hand)
64
+ return hand, bodyActions + laterActions
65
+ return g
66
+ return f
67
+ def _moveHand(n):
68
+ return lambda k: lambda s: k(s.move(n))
69
+ def _reverseHand(k):
70
+ return lambda s: k(s.reverse())
71
+
72
+ class TowerContinuation(object):
73
+ def __init__(self, x, w, h):
74
+ self.x = x
75
+ self.w = w*2
76
+ self.h = h*2
77
+ def __call__(self, k):
78
+ def f(hand):
79
+ thisAction = [(self.x + hand.hand,self.w,self.h)]
80
+ hand = hand.recordBlock(thisAction[0])
81
+ hand, rest = k(hand)
82
+ return hand, thisAction + rest
83
+ return f
84
+
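+ # Note that TowerContinuation doubles the nominal block dimensions, so the
+ # "1x3" primitive defined below emits actions of the form (hand position, 2, 6).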
85
+ # name, dimensions
86
+ blocks = {
87
+ # "1x1": (1.,1.),
88
+ # "2x1": (2.,1.),
89
+ # "1x2": (1.,2.),
90
+ "3x1": (3, 1),
91
+ "1x3": (1, 3),
92
+ # "4x1": (4.,1.),
93
+ # "1x4": (1.,4.)
94
+ }
95
+
96
+
97
+ ttower = baseType("tower")
98
+ common_primitives = [
99
+ Primitive("tower_loopM", arrow(tint, arrow(tint, ttower, ttower), ttower, ttower), _simpleLoop),
100
+ Primitive("tower_embed", arrow(arrow(ttower,ttower), ttower, ttower), _embed),
101
+ ] + [Primitive(name, arrow(ttower,ttower), TowerContinuation(0, w, h))
102
+ for name, (w, h) in blocks.items()] + \
103
+ [Primitive(str(j), tint, j) for j in range(1,9) ]
104
+ primitives = common_primitives + [
105
+ Primitive("left", arrow(tint, ttower, ttower), _left),
106
+ Primitive("right", arrow(tint, ttower, ttower), _right)
107
+ ]
108
+
109
+ new_primitives = common_primitives + [
110
+ Primitive("moveHand", arrow(tint, ttower, ttower), _moveHand),
111
+ Primitive("reverseHand", arrow(ttower, ttower), _reverseHand)
112
+ ]
113
+
114
+ def executeTower(p, timeout=None):
115
+ try:
116
+ return runWithTimeout(lambda : p.evaluate([])(_empty_tower)(TowerState())[1],
117
+ timeout=timeout)
118
+ except RunWithTimeout: return None
119
+ except: return None
120
+
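+ # Usage sketch (the program here is hypothetical):
+ # executeTower(Program.parse("(lambda (1x3 $0))")) returns the one-block plan
+ # [(0, 2, 6)], or None if evaluation times out or raises.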
121
+ def animateTower(exportPrefix, p):
122
+ print(exportPrefix, p)
123
+ from dreamcoder.domains.tower.tower_common import renderPlan
124
+ state,actions = p.evaluate([])(_empty_tower)(TowerState(history=[]))
125
+ print(actions)
126
+ trajectory = state.history + [state]
127
+ print(trajectory)
128
+ print()
129
+
130
+ assert tuple(z for z in trajectory if not isinstance(z, TowerState) ) == tuple(actions)
131
+
132
+ def hd(n):
133
+ h = 0
134
+ for state in trajectory[:n]:
135
+ if isinstance(state, TowerState):
136
+ h = state.hand
137
+ return h
138
+ animation = [renderPlan([b for b in trajectory[:n] if not isinstance(b, TowerState)],
139
+ pretty=True, Lego=True,
140
+ drawHand=hd(n),
141
+ masterPlan=actions,
142
+ randomSeed=hash(exportPrefix))
143
+ for n in range(0,len(trajectory) + 1)]
144
+ import scipy.misc
145
+ import random
146
+ r = random.random()
147
+ paths = []
148
+ for n in range(len(animation)):
149
+ paths.append(f"{exportPrefix}_{n}.png")
150
+ scipy.misc.imsave(paths[-1], animation[n])
151
+ os.system(f"convert -delay 10 -loop 0 {' '.join(paths)} {exportPrefix}.gif")
152
+ # os.system(f"rm {' '.join(paths)}")
dreamcoder/domains/tower/tower_common.py ADDED
@@ -0,0 +1,173 @@
1
+ import random
2
+ import math
3
+ from dreamcoder.utilities import *
4
+
5
+ def simulateWithoutPhysics(plan,ordered=True):
6
+ def overlap(b1, b2):
8
+ (x,w,h) = b1
9
+ (x_,y_,w_,h_) = b2
10
+ x1 = x - w/2
11
+ x2 = x + w/2
12
+ x1_ = x_ - w_/2
13
+ x2_ = x_ + w_/2
14
+ if x1_ >= x2 or x1 >= x2_: return None
15
+ assert h%2 == 0 and h_%2 == 0
16
+ return y_ + h_//2 + h//2
17
+ def lowestPossibleHeight(b):
18
+ h = b[2]
19
+ assert h%2 == 0
20
+ return int(h/2)
21
+ def placeAtHeight(b,y):
22
+ (x,w,h) = b
23
+ return (x,y,w,h)
24
+ def placeBlock(world, block):
25
+ lowest = max([lowestPossibleHeight(block)] + \
26
+ [overlap(block,other)
27
+ for other in world
28
+ if overlap(block,other) is not None])
29
+ world.append(placeAtHeight(block, lowest))
30
+
31
+ w = []
32
+ for p in plan: placeBlock(w,p)
33
+ if ordered: w = list(sorted(w))
34
+ return w
35
+
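+ # Sketch: a block rests on the floor or on the highest overlapping block, so
+ # two 1x3 blocks dropped at x=0 stack:
+ # simulateWithoutPhysics([(0,2,6),(0,2,6)]) == [(0,3,2,6),(0,9,2,6)].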
36
+ def centerTower(t,hand=None, masterPlan=None):
37
+
38
+ if len(t) == 0:
39
+ if hand is None:
40
+ return t
41
+ else:
42
+ return t, hand
43
+ def getCenter(t):
44
+ x1 = max(x for x, _, _ in t)
45
+ x0 = min(x for x, _, _ in t)
46
+ c = int((x1 - x0) / 2.0) + x0
47
+ return c
48
+ c = getCenter(masterPlan or t)
49
+ t = [(x - c, w, h) for x, w, h in t]
50
+ if hand is None:
51
+ return t
52
+ else:
53
+ return t, hand - c
54
+
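+ # Sketch: centerTower([(4,2,6),(8,2,6)]) shifts x-coordinates so the tower's
+ # span is centered on 0, giving [(-2,2,6),(2,2,6)].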
55
+ def towerLength(t):
56
+ if len(t) == 0: return 0
57
+ x1 = max(x for x, _, _ in t)
58
+ x0 = min(x for x, _, _ in t)
59
+ return x1 - x0
60
+
61
+ def towerHeight(t):
62
+ y1 = max(y + h/2 for _, y, _, h in t )
63
+ y0 = min(y - h/2 for _, y, _, h in t )
64
+ return y1 - y0
65
+
66
+
67
+
68
+ def renderPlan(plan, resolution=256, window=64, floorHeight=2, borderSize=1, bodyColor=(0.,1.,1.),
69
+ borderColor=(1.,0.,0.),
70
+ truncate=None, randomSeed=None,
71
+ masterPlan=None,
72
+ pretty=False, Lego=False,
73
+ drawHand=None):
74
+ import numpy as np
75
+
76
+ if Lego: assert pretty
77
+
78
+ if drawHand is not None and drawHand is not False:
79
+ plan, drawHand = centerTower(plan, drawHand,
80
+ masterPlan=masterPlan)
81
+ else:
82
+ plan = centerTower(plan,masterPlan=masterPlan)
83
+
84
+ world = simulateWithoutPhysics(plan,
85
+ ordered=randomSeed is None)
86
+ if truncate is not None: world = world[:truncate]
87
+ a = np.zeros((resolution, resolution, 3))
88
+
89
+ def transform(x,y):
90
+ y = resolution - y*resolution/float(window)
91
+ x = resolution/2 + x*resolution/float(window)
92
+ return int(x + 0.5),int(y + 0.5)
93
+ def clip(p):
94
+ if p < 0: return 0
95
+ if p >= resolution: return resolution - 1
96
+ return int(p + 0.5)
97
+ def clear(x,y):
98
+ for xp,yp,wp,hp in world:
99
+ if x < xp + wp/2. and \
100
+ x > xp - wp/2. and \
101
+ y < yp + hp/2. and \
102
+ y > yp - hp/2.:
103
+ return False
104
+ return True
105
+ def bump(x,y,c):
106
+ size = 0.5*resolution/window
107
+ x,y = transform(x,y)
108
+ y -= floorHeight
109
+ y1 = y
110
+ y2 = y - size
111
+ x1 = x - size/2
112
+ x2 = x + size/2
113
+ a[clip(y2) : clip(y1),
114
+ clip(x1) : clip(x2),
115
+ :] = c
116
+
117
+
118
+ if randomSeed is not None:
119
+ randomNumbers = random.Random(randomSeed)
120
+ def _color():
121
+ if randomSeed is None:
122
+ return random.random()*0.7 + 0.3
123
+ else:
124
+ return randomNumbers.random()*0.7 + 0.3
125
+ def color():
126
+ return (_color(),_color(),_color())
127
+
128
+ def rectangle(x1,x2,y1,y2,c,cp=None):
129
+ x1,y1 = transform(x1,y1)
130
+ x2,y2 = transform(x2,y2)
131
+ y1 -= floorHeight
132
+ y2 -= floorHeight
133
+ a[clip(y2) : clip(y1),
134
+ clip(x1) : clip(x2),
135
+ :] = c
136
+ if cp is not None:
137
+ a[clip(y2 + borderSize) : clip(y1 - borderSize),
138
+ clip(x1 + borderSize) : clip(x2 - borderSize),
139
+ :] = cp
140
+
141
+ for x,y,w,h in world:
142
+ x1,y1 = x - w/2., y - h/2.
143
+ x2,y2 = x + w/2., y + h/2.
144
+ if pretty:
145
+ thisColor = color()
146
+ rectangle(x1,x2,y1,y2,
147
+ thisColor)
148
+ if Lego:
149
+ bumps = w
150
+ for nb in range(bumps):
151
+ nx = x - w/2. + 0.5 + nb
152
+ ny = y + h/2. + 0.00001
153
+ if clear(nx,ny):
154
+ bump(nx,ny,thisColor)
155
+ else:
156
+ rectangle(x1,x2,y1,y2,
157
+ borderColor, bodyColor)
158
+
159
+ a[resolution - floorHeight:,:,:] = 1.
160
+ if drawHand is not None:
161
+ if not Lego:
162
+ dh = 0.25
163
+ rectangle(drawHand - dh,
164
+ drawHand + dh,
165
+ -99999, 99999,
166
+ (0,1,0))
167
+ else:
168
+ rectangle(drawHand - 1,drawHand + 1,
169
+ 43,45,(1,1,1))
170
+
171
+ return a
172
+
173
+
dreamcoder/dreamcoder.py ADDED
@@ -0,0 +1,1074 @@
1
+ import datetime
2
+
3
+ import dill
4
+
5
+ from dreamcoder.compression import induceGrammar
6
+ from dreamcoder.utilities import *
7
+ try:
8
+ from dreamcoder.recognition import *
9
+ except:
10
+ eprint("Failure loading recognition - only acceptable if using pypy ")
11
+ from dreamcoder.enumeration import *
12
+ from dreamcoder.fragmentGrammar import *
13
+ from dreamcoder.taskBatcher import *
14
+ from dreamcoder.primitiveGraph import graphPrimitives
15
+ from dreamcoder.dreaming import backgroundHelmholtzEnumeration
16
+
17
+
18
+ class ECResult():
19
+ def __init__(self, _=None,
20
+ frontiersOverTime=None,
21
+ testingSearchTime=None,
22
+ learningCurve=None,
23
+ grammars=None,
24
+ taskSolutions=None,
25
+ averageDescriptionLength=None,
26
+ parameters=None,
27
+ recognitionModel=None,
28
+ searchTimes=None,
29
+ recognitionTaskMetrics=None,
30
+ numTestingTasks=None,
31
+ sumMaxll=None,
32
+ testingSumMaxll=None,
33
+ hitsAtEachWake=None,
34
+ timesAtEachWake=None,
35
+ allFrontiers=None):
36
+ self.frontiersOverTime = frontiersOverTime or {} # Map from task to [frontier at iteration 1, frontier at iteration 2, ...]
37
+ self.hitsAtEachWake = hitsAtEachWake or []
38
+ self.timesAtEachWake = timesAtEachWake or []
39
+ self.testingSearchTime = testingSearchTime or []
40
+ self.searchTimes = searchTimes or []
41
+ self.trainSearchTime = {} # map from task to search time
42
+ self.testSearchTime = {} # map from task to search time
43
+ self.recognitionTaskMetrics = recognitionTaskMetrics or {}
44
+ self.recognitionModel = recognitionModel
45
+ self.averageDescriptionLength = averageDescriptionLength or []
46
+ self.parameters = parameters
47
+ self.learningCurve = learningCurve or []
48
+ self.grammars = grammars or []
49
+ self.taskSolutions = taskSolutions or {}
50
+ self.numTestingTasks = numTestingTasks
51
+ self.sumMaxll = sumMaxll or [] #TODO name change
52
+ self.testingSumMaxll = testingSumMaxll or [] #TODO name change
53
+ self.allFrontiers = allFrontiers or {}
54
+
55
+ def __repr__(self):
56
+ attrs = ["{}={}".format(k, v) for k, v in self.__dict__.items()]
57
+ return "ECResult({})".format(", ".join(attrs))
58
+
59
+ def getTestingTasks(self):
60
+ testing = []
61
+ training = self.taskSolutions.keys()
62
+ for t in self.recognitionTaskMetrics:
63
+ if isinstance(t, Task) and t not in training: testing.append(t)
64
+ return testing
65
+
66
+
67
+ def recordFrontier(self, frontier):
68
+ t = frontier.task
69
+ if t not in self.frontiersOverTime: self.frontiersOverTime[t] = []
70
+ self.frontiersOverTime[t].append(frontier)
71
+
72
+ # Linux does not like file names with more than 256 characters,
73
+ # so when exporting the results we abbreviate the parameter names
74
+ abbreviations = {"frontierSize": "fs",
75
+ "useDSL": "DSL",
76
+ "taskReranker": "TRR",
77
+ "matrixRank": "MR",
78
+ "reuseRecognition": "RR",
79
+ "ensembleSize": "ES",
80
+ "recognitionTimeout": "RT",
81
+ "recognitionSteps": "RS",
82
+ "iterations": "it",
83
+ "maximumFrontier": "MF",
84
+ "pseudoCounts": "pc",
85
+ "auxiliaryLoss": "aux",
86
+ "structurePenalty": "L",
87
+ "helmholtzRatio": "HR",
88
+ "biasOptimal": "BO",
89
+ "contextual": "CO",
90
+ "topK": "K",
91
+ "enumerationTimeout": "ET",
92
+ "useRecognitionModel": "rec",
93
+ "use_ll_cutoff": "llcut",
94
+ "topk_use_only_likelihood": "topkNotMAP",
95
+ "activation": "act",
96
+ "storeTaskMetrics": 'STM',
97
+ "topkNotMAP": "tknm",
98
+ "rewriteTaskMetrics": "RW",
99
+ 'taskBatchSize': 'batch'}
100
+
101
+ @staticmethod
102
+ def abbreviate(parameter): return ECResult.abbreviations.get(parameter, parameter)
103
+
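+ # e.g. ECResult.abbreviate("enumerationTimeout") == "ET"; unknown keys pass
+ # through unchanged.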
104
+ @staticmethod
105
+ def parameterOfAbbreviation(abbreviation):
106
+ return ECResult.abbreviationToParameter.get(abbreviation, abbreviation)
107
+
108
+ @staticmethod
109
+ def clearRecognitionModel(path):
110
+ SUFFIX = '.pickle'
111
+ assert path.endswith(SUFFIX)
112
+
113
+ with open(path,'rb') as handle:
114
+ result = dill.load(handle)
115
+
116
+ result.recognitionModel = None
117
+
118
+ clearedPath = path[:-len(SUFFIX)] + "_graph=True" + SUFFIX
119
+ with open(clearedPath,'wb') as handle:
120
+ result = dill.dump(result, handle)
121
+ eprint(" [+] Cleared recognition model from:")
122
+ eprint(" %s"%path)
123
+ eprint(" and exported to:")
124
+ eprint(" %s"%clearedPath)
125
+ eprint(" Use this one for graphing.")
126
+
127
+
128
+ ECResult.abbreviationToParameter = {
129
+ v: k for k, v in ECResult.abbreviations.items()}
130
+
131
+
132
+ def explorationCompression(*arguments, **keywords):
133
+ for r in ecIterator(*arguments, **keywords):
134
+ pass
135
+ return r
136
+
137
+
138
+ def ecIterator(grammar, tasks,
139
+ _=None,
140
+ useDSL=True,
141
+ noConsolidation=False,
142
+ mask=False,
143
+ seed=0,
144
+ addFullTaskMetrics=False,
145
+ matrixRank=None,
146
+ solver='ocaml',
147
+ compressor="rust",
148
+ biasOptimal=False,
149
+ contextual=False,
150
+ testingTasks=[],
151
+ iterations=None,
152
+ resume=None,
153
+ enumerationTimeout=None,
154
+ testingTimeout=None,
155
+ testEvery=1,
156
+ reuseRecognition=False,
157
+ ensembleSize=1,
158
+ useRecognitionModel=True,
159
+ recognitionTimeout=None,
160
+ recognitionSteps=None,
161
+ helmholtzRatio=0.,
162
+ featureExtractor=None,
163
+ activation='relu',
164
+ topK=1,
165
+ topk_use_only_likelihood=False,
166
+ use_map_search_times=True,
167
+ maximumFrontier=None,
168
+ pseudoCounts=1.0, aic=1.0,
169
+ structurePenalty=0.001, arity=0,
170
+ evaluationTimeout=1.0, # seconds
171
+ taskBatchSize=None,
172
+ taskReranker='default',
173
+ CPUs=1,
174
+ cuda=False,
175
+ message="",
176
+ outputPrefix=None,
177
+ storeTaskMetrics=False,
178
+ rewriteTaskMetrics=True,
179
+ auxiliaryLoss=False,
180
+ custom_wake_generative=None):
181
+ if enumerationTimeout is None:
182
+ eprint(
183
+ "Please specify an enumeration timeout:",
184
+ "explorationCompression(..., enumerationTimeout = ..., ...)")
185
+ assert False
186
+ if iterations is None:
187
+ eprint(
188
+ "Please specify a iteration count: explorationCompression(..., iterations = ...)")
189
+ assert False
190
+ if useRecognitionModel and featureExtractor is None:
191
+ eprint("Warning: Recognition model needs feature extractor.",
192
+ "Ignoring recognition model.")
193
+ useRecognitionModel = False
194
+ if ensembleSize > 1 and not useRecognitionModel:
195
+ eprint("Warning: ensemble size requires using the recognition model, aborting.")
196
+ assert False
197
+ if biasOptimal and not useRecognitionModel:
198
+ eprint("Bias optimality only applies to recognition models, aborting.")
199
+ assert False
200
+ if contextual and not useRecognitionModel:
201
+ eprint("Contextual only applies to recognition models, aborting")
202
+ assert False
203
+ if reuseRecognition and not useRecognitionModel:
204
+ eprint("Reuse of recognition model weights at successive iteration only applies to recognition models, aborting")
205
+ assert False
206
+ if matrixRank is not None and not contextual:
207
+ eprint("Matrix rank only applies to contextual recognition models, aborting")
208
+ assert False
209
+ assert useDSL or useRecognitionModel, "You specified that you didn't want to use the DSL AND you don't want to use the recognition model. Figure out what you want to use."
210
+ if testingTimeout > 0 and len(testingTasks) == 0:
211
+ eprint("You specified a testingTimeout, but did not provide any held out testing tasks, aborting.")
212
+ assert False
213
+
214
+ # We save the parameters that were passed into EC
215
+ # This is for the purpose of exporting the results of the experiment
216
+ parameters = {
217
+ k: v for k,
218
+ v in locals().items() if k not in {
219
+ "tasks",
220
+ "use_map_search_times",
221
+ "seed",
222
+ "activation",
223
+ "grammar",
224
+ "cuda",
225
+ "_",
226
+ "testingTimeout",
227
+ "testEvery",
228
+ "message",
229
+ "CPUs",
230
+ "outputPrefix",
231
+ "resume",
232
+ "resumeFrontierSize",
233
+ "addFullTaskMetrics",
234
+ "featureExtractor",
235
+ "evaluationTimeout",
236
+ "testingTasks",
237
+ "compressor",
238
+ "custom_wake_generative"} and v is not None}
239
+ if not useRecognitionModel:
240
+ for k in {"helmholtzRatio", "recognitionTimeout", "biasOptimal", "mask",
241
+ "contextual", "matrixRank", "reuseRecognition", "auxiliaryLoss", "ensembleSize"}:
242
+ if k in parameters: del parameters[k]
243
+ else: del parameters["useRecognitionModel"]
244
+ if useRecognitionModel and not contextual:
245
+ if "matrixRank" in parameters:
246
+ del parameters["matrixRank"]
247
+ if "mask" in parameters:
248
+ del parameters["mask"]
249
+ if not mask and 'mask' in parameters: del parameters["mask"]
250
+ if not auxiliaryLoss and 'auxiliaryLoss' in parameters: del parameters['auxiliaryLoss']
251
+ if not useDSL:
252
+ for k in {"structurePenalty", "pseudoCounts", "aic"}:
253
+ del parameters[k]
254
+ else: del parameters["useDSL"]
255
+
256
+ # Uses `parameters` to construct the checkpoint path
257
+ def checkpointPath(iteration, extra=""):
258
+ parameters["iterations"] = iteration
259
+ kvs = [
260
+ "{}={}".format(
261
+ ECResult.abbreviate(k),
262
+ parameters[k]) for k in sorted(
263
+ parameters.keys())]
264
+ return "{}_{}{}.pickle".format(outputPrefix, "_".join(kvs), extra)
265
+
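+ # Illustrative checkpoint name (parameter values here are hypothetical):
+ # experimentOutputs/towers/tower_ET=600_HR=0.5_it=3_MF=10.pickle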
266
+ if message:
267
+ message = " (" + message + ")"
268
+ eprint("Running EC%s on %s @ %s with %d CPUs and parameters:" %
269
+ (message, os.uname()[1], datetime.datetime.now(), CPUs))
270
+ for k, v in parameters.items():
271
+ eprint("\t", k, " = ", v)
272
+ eprint("\t", "evaluationTimeout", " = ", evaluationTimeout)
273
+ eprint("\t", "cuda", " = ", cuda)
274
+ eprint()
275
+
276
+ if addFullTaskMetrics:
277
+ assert resume is not None, "--addFullTaskMetrics requires --resume"
278
+
279
+ def reportMemory():
280
+ eprint(f"Currently using this much memory: {getThisMemoryUsage()}")
281
+
282
+ # Restore checkpoint
283
+ if resume is not None:
284
+ try:
285
+ resume = int(resume)
286
+ path = checkpointPath(resume)
287
+ except ValueError:
288
+ path = resume
289
+ with open(path, "rb") as handle:
290
+ result = dill.load(handle)
291
+ resume = len(result.grammars) - 1
292
+ eprint("Loaded checkpoint from", path)
293
+ grammar = result.grammars[-1] if result.grammars else grammar
294
+ else: # Start from scratch
295
+ #for graphing of testing tasks
296
+ numTestingTasks = len(testingTasks) if len(testingTasks) != 0 else None
297
+
298
+ result = ECResult(parameters=parameters,
299
+ grammars=[grammar],
300
+ taskSolutions={
301
+ t: Frontier([],
302
+ task=t) for t in tasks},
303
+ recognitionModel=None, numTestingTasks=numTestingTasks,
304
+ allFrontiers={
305
+ t: Frontier([],
306
+ task=t) for t in tasks})
307
+
308
+
309
+ # Set up the task batcher.
310
+ if taskReranker == 'default':
311
+ taskBatcher = DefaultTaskBatcher()
312
+ elif taskReranker == 'random':
313
+ taskBatcher = RandomTaskBatcher()
314
+ elif taskReranker == 'randomShuffle':
315
+ taskBatcher = RandomShuffleTaskBatcher(seed)
316
+ elif taskReranker == 'unsolved':
317
+ taskBatcher = UnsolvedTaskBatcher()
318
+ elif taskReranker == 'unsolvedEntropy':
319
+ taskBatcher = UnsolvedEntropyTaskBatcher()
320
+ elif taskReranker == 'unsolvedRandomEntropy':
321
+ taskBatcher = UnsolvedRandomEntropyTaskBatcher()
322
+ elif taskReranker == 'randomkNN':
323
+ taskBatcher = RandomkNNTaskBatcher()
324
+ elif taskReranker == 'randomLowEntropykNN':
325
+ taskBatcher = RandomLowEntropykNNTaskBatcher()
326
+ else:
327
+ eprint("Invalid task reranker: " + taskReranker + ", aborting.")
328
+ assert False
329
+
330
+ # Check if we are just updating the full task metrics
331
+ if addFullTaskMetrics:
332
+ if testingTimeout is not None and testingTimeout > enumerationTimeout:
333
+ enumerationTimeout = testingTimeout
334
+ if result.recognitionModel is not None:
335
+ _enumerator = lambda *args, **kw: result.recognitionModel.enumerateFrontiers(*args, **kw)
336
+ else: _enumerator = lambda *args, **kw: multicoreEnumeration(result.grammars[-1], *args, **kw)
337
+ enumerator = lambda *args, **kw: _enumerator(*args,
338
+ maximumFrontier=maximumFrontier,
339
+ CPUs=CPUs, evaluationTimeout=evaluationTimeout,
340
+ solver=solver,
341
+ **kw)
342
+ trainFrontiers, _, trainingTimes = enumerator(tasks, enumerationTimeout=enumerationTimeout)
343
+ testFrontiers, _, testingTimes = enumerator(testingTasks, enumerationTimeout=testingTimeout, testing=True)
344
+
345
+ recognizer = result.recognitionModel
346
+ updateTaskSummaryMetrics(result.recognitionTaskMetrics, trainingTimes, 'recognitionBestTimes')
347
+ updateTaskSummaryMetrics(result.recognitionTaskMetrics, recognizer.taskGrammarLogProductions(tasks), 'taskLogProductions')
348
+ updateTaskSummaryMetrics(result.recognitionTaskMetrics, recognizer.taskGrammarEntropies(tasks), 'taskGrammarEntropies')
349
+ updateTaskSummaryMetrics(result.recognitionTaskMetrics, result.recognitionModel.taskAuxiliaryLossLayer(tasks), 'taskAuxiliaryLossLayer')
350
+
351
+ updateTaskSummaryMetrics(result.recognitionTaskMetrics, testingTimes, 'heldoutTestingTimes')
352
+ updateTaskSummaryMetrics(result.recognitionTaskMetrics, recognizer.taskGrammarLogProductions(testingTasks), 'heldoutTaskLogProductions')
353
+ updateTaskSummaryMetrics(result.recognitionTaskMetrics, recognizer.taskGrammarEntropies(testingTasks), 'heldoutTaskGrammarEntropies')
354
+ updateTaskSummaryMetrics(result.recognitionTaskMetrics, result.recognitionModel.taskAuxiliaryLossLayer(testingTasks), 'heldoutAuxiliaryLossLayer')
355
+
356
+ updateTaskSummaryMetrics(result.recognitionTaskMetrics, {f.task: f
357
+ for f in trainFrontiers + testFrontiers
358
+ if len(f) > 0},
359
+ 'frontier')
360
+ SUFFIX = ".pickle"
361
+ assert path.endswith(SUFFIX)
362
+ path = path[:-len(SUFFIX)] + "_FTM=True" + SUFFIX
363
+ with open(path, "wb") as handle: dill.dump(result, handle)
364
+ if useRecognitionModel: ECResult.clearRecognitionModel(path)
365
+
366
+ sys.exit(0)
367
+
368
+
369
+ for j in range(resume or 0, iterations):
370
+ if storeTaskMetrics and rewriteTaskMetrics:
371
+ eprint("Resetting task metrics for next iteration.")
372
+ result.recognitionTaskMetrics = {}
373
+
374
+ reportMemory()
375
+
376
+ # Evaluate on held out tasks if we have them
377
+ if testingTimeout > 0 and ((j % testEvery == 0) or (j == iterations - 1)):
378
+ eprint("Evaluating on held out testing tasks for iteration: %d" % (j))
379
+ evaluateOnTestingTasks(result, testingTasks, grammar,
380
+ CPUs=CPUs, maximumFrontier=maximumFrontier,
381
+ solver=solver,
382
+ enumerationTimeout=testingTimeout, evaluationTimeout=evaluationTimeout)
383
+ # If we have to also enumerate Helmholtz frontiers,
384
+ # do this extra sneaky in the background
385
+ if useRecognitionModel and biasOptimal and helmholtzRatio > 0 and \
386
+ all( str(p) != "REAL" for p in grammar.primitives ): # real numbers don't support this
387
+ # the DSL is fixed, so the dreams are also fixed. don't recompute them.
388
+ if useDSL or 'helmholtzFrontiers' not in locals():
389
+ helmholtzFrontiers = backgroundHelmholtzEnumeration(tasks, grammar, enumerationTimeout,
390
+ evaluationTimeout=evaluationTimeout,
391
+ special=featureExtractor.special)
392
+ else:
393
+ print("Reusing dreams from previous iteration.")
394
+ else:
395
+ helmholtzFrontiers = lambda: []
396
+
397
+ reportMemory()
398
+
399
+ # Get waking task batch.
400
+ wakingTaskBatch = taskBatcher.getTaskBatch(result, tasks, taskBatchSize, j)
401
+ eprint("Using a waking task batch of size: " + str(len(wakingTaskBatch)))
402
+
403
+ # WAKING UP
404
+ if useDSL:
405
+ wake_generative = custom_wake_generative if custom_wake_generative is not None else default_wake_generative
406
+ topDownFrontiers, times = wake_generative(grammar, wakingTaskBatch,
407
+ solver=solver,
408
+ maximumFrontier=maximumFrontier,
409
+ enumerationTimeout=enumerationTimeout,
410
+ CPUs=CPUs,
411
+ evaluationTimeout=evaluationTimeout)
412
+ result.trainSearchTime = {t: tm for t, tm in times.items() if tm is not None}
413
+ else:
414
+ eprint("Skipping top-down enumeration because we are not using the generative model")
415
+ topDownFrontiers, times = [], {t: None for t in wakingTaskBatch }
416
+
417
+ tasksHitTopDown = {f.task for f in topDownFrontiers if not f.empty}
418
+ result.hitsAtEachWake.append(len(tasksHitTopDown))
419
+
420
+ reportMemory()
421
+
422
+ # Combine topDownFrontiers from this task batch with all frontiers.
423
+ for f in topDownFrontiers:
424
+ if f.task not in result.allFrontiers: continue # backward compatibility with old checkpoints
425
+ result.allFrontiers[f.task] = result.allFrontiers[f.task].combine(f).topK(maximumFrontier)
426
+
427
+ eprint("Frontiers discovered top down: " + str(len(tasksHitTopDown)))
428
+ eprint("Total frontiers: " + str(len([f for f in result.allFrontiers.values() if not f.empty])))
429
+
430
+ # Train + use recognition model
431
+ if useRecognitionModel:
432
+ # Should we initialize the weights to be what they were before?
433
+ previousRecognitionModel = None
434
+ if reuseRecognition and result.recognitionModel is not None:
435
+ previousRecognitionModel = result.recognitionModel
436
+
437
+ thisRatio = helmholtzRatio
438
+ #if j == 0 and not biasOptimal: thisRatio = 0
439
+ if all( f.empty for f in result.allFrontiers.values() ): thisRatio = 1.
440
+
441
+ tasksHitBottomUp = \
442
+ sleep_recognition(result, grammar, wakingTaskBatch, tasks, testingTasks, result.allFrontiers.values(),
443
+ ensembleSize=ensembleSize, featureExtractor=featureExtractor, mask=mask,
444
+ activation=activation, contextual=contextual, biasOptimal=biasOptimal,
445
+ previousRecognitionModel=previousRecognitionModel, matrixRank=matrixRank,
446
+ timeout=recognitionTimeout, evaluationTimeout=evaluationTimeout,
447
+ enumerationTimeout=enumerationTimeout,
448
+ helmholtzRatio=thisRatio, helmholtzFrontiers=helmholtzFrontiers(),
449
+ auxiliaryLoss=auxiliaryLoss, cuda=cuda, CPUs=CPUs, solver=solver,
450
+ recognitionSteps=recognitionSteps, maximumFrontier=maximumFrontier)
451
+
452
+ showHitMatrix(tasksHitTopDown, tasksHitBottomUp, wakingTaskBatch)
453
+
454
+ # Record the new topK solutions
455
+ result.taskSolutions = {f.task: f.topK(topK)
456
+ for f in result.allFrontiers.values()}
457
+ for f in result.allFrontiers.values(): result.recordFrontier(f)
458
+ result.learningCurve += [
459
+ sum(f is not None and not f.empty for f in result.taskSolutions.values())]
460
+ updateTaskSummaryMetrics(result.recognitionTaskMetrics, {f.task: f
461
+ for f in result.allFrontiers.values()
462
+ if len(f) > 0},
463
+ 'frontier')
464
+
465
+ # Sleep-G
466
+ if useDSL and not(noConsolidation):
467
+ eprint(f"Currently using this much memory: {getThisMemoryUsage()}")
468
+ grammar = consolidate(result, grammar, topK=topK, pseudoCounts=pseudoCounts, arity=arity, aic=aic,
469
+ structurePenalty=structurePenalty, compressor=compressor, CPUs=CPUs,
470
+ iteration=j)
471
+ eprint(f"Currently using this much memory: {getThisMemoryUsage()}")
472
+ else:
473
+ eprint("Skipping consolidation.")
474
+ result.grammars.append(grammar)
475
+
476
+ if outputPrefix is not None:
477
+ path = checkpointPath(j + 1)
478
+ with open(path, "wb") as handle:
479
+ try:
480
+ dill.dump(result, handle)
481
+ except TypeError as e:
482
+ eprint(result)
483
+ assert(False)
484
+ eprint("Exported checkpoint to", path)
485
+ if useRecognitionModel:
486
+ ECResult.clearRecognitionModel(path)
487
+
488
+ graphPrimitives(result, "%s_primitives_%d_"%(outputPrefix,j))
489
+
490
+
491
+ yield result
492
+
493
+
494
+ def showHitMatrix(top, bottom, tasks):
495
+ tasks = set(tasks)
496
+
497
+ total = bottom | top
498
+ eprint(len(total), "/", len(tasks), "total hit tasks")
499
+ bottomMiss = tasks - bottom
500
+ topMiss = tasks - top
501
+
502
+ eprint("{: <13s}{: ^13s}{: ^13s}".format("", "bottom miss", "bottom hit"))
503
+ eprint("{: <13s}{: ^13d}{: ^13d}".format("top miss",
504
+ len(bottomMiss & topMiss),
505
+ len(bottom & topMiss)))
506
+ eprint("{: <13s}{: ^13d}{: ^13d}".format("top hit",
507
+ len(top & bottomMiss),
508
+ len(top & bottom)))
509
+
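+ # "top" is the set of tasks solved by enumerating from the generative grammar
+ # and "bottom" the set solved from the recognition model; the 2x2 table
+ # printed above shows their overlap.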
510
+ def evaluateOnTestingTasks(result, testingTasks, grammar, _=None,
511
+ CPUs=None, solver=None, maximumFrontier=None, enumerationTimeout=None, evaluationTimeout=None):
512
+ if result.recognitionModel is not None:
513
+ recognizer = result.recognitionModel
514
+ testingFrontiers, times = \
515
+ recognizer.enumerateFrontiers(testingTasks,
516
+ CPUs=CPUs,
517
+ solver=solver,
518
+ maximumFrontier=maximumFrontier,
519
+ enumerationTimeout=enumerationTimeout,
520
+ evaluationTimeout=evaluationTimeout,
521
+ testing=True)
522
+ updateTaskSummaryMetrics(result.recognitionTaskMetrics, recognizer.taskGrammarLogProductions(testingTasks), 'heldoutTaskLogProductions')
523
+ updateTaskSummaryMetrics(result.recognitionTaskMetrics, recognizer.taskGrammarEntropies(testingTasks), 'heldoutTaskGrammarEntropies')
525
+ else:
526
+ testingFrontiers, times = multicoreEnumeration(grammar, testingTasks,
527
+ solver=solver,
528
+ maximumFrontier=maximumFrontier,
529
+ enumerationTimeout=enumerationTimeout,
530
+ CPUs=CPUs,
531
+ evaluationTimeout=evaluationTimeout,
532
+ testing=True)
533
+ updateTaskSummaryMetrics(result.recognitionTaskMetrics, times, 'heldoutTestingTimes')
534
+ updateTaskSummaryMetrics(result.recognitionTaskMetrics,
535
+ {f.task: f for f in testingFrontiers if len(f) > 0 },
536
+ 'frontier')
537
+ for f in testingFrontiers: result.recordFrontier(f)
538
+ result.testSearchTime = {t: tm for t, tm in times.items() if tm is not None}
539
+ times = [t for t in times.values() if t is not None ]
540
+ eprint("\n".join(f.summarize() for f in testingFrontiers))
541
+ summaryStatistics("Testing tasks", times)
542
+ eprint("Hits %d/%d testing tasks" % (len(times), len(testingTasks)))
543
+ result.testingSearchTime.append(times)
544
+
545
+
546
+ def default_wake_generative(grammar, tasks,
547
+ maximumFrontier=None,
548
+ enumerationTimeout=None,
549
+ CPUs=None,
550
+ solver=None,
551
+ evaluationTimeout=None):
552
+ topDownFrontiers, times = multicoreEnumeration(grammar, tasks,
553
+ maximumFrontier=maximumFrontier,
554
+ enumerationTimeout=enumerationTimeout,
555
+ CPUs=CPUs,
556
+ solver=solver,
557
+ evaluationTimeout=evaluationTimeout)
558
+ eprint("Generative model enumeration results:")
559
+ eprint(Frontier.describe(topDownFrontiers))
560
+ summaryStatistics("Generative model", [t for t in times.values() if t is not None])
561
+ return topDownFrontiers, times
562
+
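A rough call sketch for the generative wake phase (parameter values here are illustrative, not defaults):

    frontiers, times = default_wake_generative(grammar, tasks,
                                               maximumFrontier=5,
                                               enumerationTimeout=120,
                                               CPUs=4,
                                               solver='ocaml',
                                               evaluationTimeout=1.0)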
563
+ def sleep_recognition(result, grammar, taskBatch, tasks, testingTasks, allFrontiers, _=None,
564
+ ensembleSize=1, featureExtractor=None, matrixRank=None, mask=False,
565
+ activation=None, contextual=True, biasOptimal=True,
566
+ previousRecognitionModel=None, recognitionSteps=None,
567
+ timeout=None, enumerationTimeout=None, evaluationTimeout=None,
568
+ helmholtzRatio=None, helmholtzFrontiers=None, maximumFrontier=None,
569
+ auxiliaryLoss=None, cuda=None, CPUs=None, solver=None):
570
+ eprint("Using an ensemble size of %d. Note that we will only store and test on the best recognition model." % ensembleSize)
571
+
572
+ featureExtractorObjects = [featureExtractor(tasks, testingTasks=testingTasks, cuda=cuda) for i in range(ensembleSize)]
573
+ recognizers = [RecognitionModel(featureExtractorObjects[i],
574
+ grammar,
575
+ mask=mask,
576
+ rank=matrixRank,
577
+ activation=activation,
578
+ cuda=cuda,
579
+ contextual=contextual,
580
+ previousRecognitionModel=previousRecognitionModel,
581
+ id=i) for i in range(ensembleSize)]
582
+ eprint(f"Currently using this much memory: {getThisMemoryUsage()}")
583
+ trainedRecognizers = parallelMap(min(CPUs,len(recognizers)),
584
+ lambda recognizer: recognizer.train(allFrontiers,
585
+ biasOptimal=biasOptimal,
586
+ helmholtzFrontiers=helmholtzFrontiers,
587
+ CPUs=CPUs,
588
+ evaluationTimeout=evaluationTimeout,
589
+ timeout=timeout,
590
+ steps=recognitionSteps,
591
+ helmholtzRatio=helmholtzRatio,
592
+ auxLoss=auxiliaryLoss,
593
+ vectorized=True),
594
+ recognizers,
595
+ seedRandom=True)
596
+ eprint(f"Currently using this much memory: {getThisMemoryUsage()}")
597
+ # Enumerate frontiers for each of the recognizers.
598
+ eprint("Trained an ensemble of %d recognition models, now enumerating." % len(trainedRecognizers))
599
+ ensembleFrontiers, ensembleTimes, ensembleRecognitionTimes = [], [], []
600
+ mostTasks = 0
601
+ bestRecognizer = None
602
+ totalTasksHitBottomUp = set()
603
+ for recIndex, recognizer in enumerate(trainedRecognizers):
604
+ eprint("Enumerating from recognizer %d of %d" % (recIndex, len(trainedRecognizers)))
605
+ bottomupFrontiers, allRecognitionTimes = \
606
+ recognizer.enumerateFrontiers(taskBatch,
607
+ CPUs=CPUs,
608
+ maximumFrontier=maximumFrontier,
609
+ enumerationTimeout=enumerationTimeout,
610
+ evaluationTimeout=evaluationTimeout,
611
+ solver=solver)
612
+ ensembleFrontiers.append(bottomupFrontiers)
613
+ ensembleTimes.append([t for t in allRecognitionTimes.values() if t is not None])
614
+ ensembleRecognitionTimes.append(allRecognitionTimes)
615
+
616
+ recognizerTasksHitBottomUp = {f.task for f in bottomupFrontiers if not f.empty}
617
+ totalTasksHitBottomUp.update(recognizerTasksHitBottomUp)
618
+ eprint("Recognizer %d solved %d/%d tasks; total tasks solved is now %d." % (recIndex, len(recognizerTasksHitBottomUp), len(tasks), len(totalTasksHitBottomUp)))
619
+ if len(recognizerTasksHitBottomUp) >= mostTasks:
620
+ # TODO (cathywong): could consider keeping the one that put the highest likelihood on the solved tasks.
621
+ bestRecognizer = recIndex
622
+
623
+ # Store the recognizer that discovers the most frontiers in the result.
624
+ eprint("Best recognizer: %d." % bestRecognizer)
625
+ result.recognitionModel = trainedRecognizers[bestRecognizer]
626
+ result.trainSearchTime = {tk: tm for tk, tm in ensembleRecognitionTimes[bestRecognizer].items()
627
+ if tm is not None}
628
+ updateTaskSummaryMetrics(result.recognitionTaskMetrics, ensembleRecognitionTimes[bestRecognizer], 'recognitionBestTimes')
629
+ updateTaskSummaryMetrics(result.recognitionTaskMetrics, result.recognitionModel.taskHiddenStates(tasks), 'hiddenState')
630
+ updateTaskSummaryMetrics(result.recognitionTaskMetrics, result.recognitionModel.taskGrammarLogProductions(tasks), 'taskLogProductions')
631
+ updateTaskSummaryMetrics(result.recognitionTaskMetrics, result.recognitionModel.taskGrammarEntropies(tasks), 'taskGrammarEntropies')
632
+ if contextual:
633
+ updateTaskSummaryMetrics(result.recognitionTaskMetrics,
634
+ result.recognitionModel.taskGrammarStartProductions(tasks),
635
+ 'startProductions')
636
+
637
+ result.hitsAtEachWake.append(len(totalTasksHitBottomUp))
638
+ eprint(f"Currently using this much memory: {getThisMemoryUsage()}")
639
+
640
+ """ Rescore and combine the frontiers across the ensemble of recognition models."""
641
+ eprint("Recognition model enumeration results for the best recognizer.")
642
+ eprint(Frontier.describe(ensembleFrontiers[bestRecognizer]))
643
+ summaryStatistics("Recognition model", ensembleTimes[bestRecognizer])
644
+
645
+ eprint("Cumulative results for the full ensemble of %d recognizers: " % len(trainedRecognizers))
646
+ # Rescore all of the ensemble frontiers according to the generative model
647
+ # and then combine w/ original frontiers
648
+ for bottomupFrontiers in ensembleFrontiers:
649
+ for b in bottomupFrontiers:
650
+ if b.task not in result.allFrontiers: continue # backwards compatibility with old checkpoints
651
+ result.allFrontiers[b.task] = result.allFrontiers[b.task].\
652
+ combine(grammar.rescoreFrontier(b)).\
653
+ topK(maximumFrontier)
654
+
655
+ eprint("Frontiers discovered bottom up: " + str(len(totalTasksHitBottomUp)))
656
+ eprint("Total frontiers: " + str(len([f for f in result.allFrontiers.values() if not f.empty])))
657
+
658
+ result.searchTimes.append(ensembleTimes[bestRecognizer])
659
+ if len(ensembleTimes[bestRecognizer]) > 0:
660
+ eprint("Average search time: ", int(mean(ensembleTimes[bestRecognizer]) + 0.5),
661
+ "sec.\tmedian:", int(median(ensembleTimes[bestRecognizer]) + 0.5),
662
+ "\tmax:", int(max(ensembleTimes[bestRecognizer]) + 0.5),
663
+ "\tstandard deviation", int(standardDeviation(ensembleTimes[bestRecognizer]) + 0.5))
664
+ return totalTasksHitBottomUp
665
+
666
+ def consolidate(result, grammar, _=None, topK=None, arity=None, pseudoCounts=None, aic=None,
667
+ structurePenalty=None, compressor=None, CPUs=None, iteration=None):
668
+ eprint("Showing the top 5 programs in each frontier being sent to the compressor:")
669
+ for f in result.allFrontiers.values():
670
+ if f.empty:
671
+ continue
672
+ eprint(f.task)
673
+ for e in f.normalize().topK(5):
674
+ eprint("%.02f\t%s" % (e.logPosterior, e.program))
675
+ eprint()
676
+
677
+ # First check if we have supervision at the program level for any task that was not solved
678
+ needToSupervise = {f.task for f in result.allFrontiers.values()
679
+ if f.task.supervision is not None and f.empty}
680
+ compressionFrontiers = [f.replaceWithSupervised(grammar) if f.task in needToSupervise else f
681
+ for f in result.allFrontiers.values() ]
682
+
683
+ if len([f for f in compressionFrontiers if not f.empty]) == 0:
684
+ eprint("No compression frontiers; not inducing a grammar this iteration.")
685
+ else:
686
+ grammar, compressionFrontiers = induceGrammar(grammar, compressionFrontiers,
687
+ topK=topK,
688
+ pseudoCounts=pseudoCounts, a=arity,
689
+ aic=aic, structurePenalty=structurePenalty,
690
+ topk_use_only_likelihood=False,
691
+ backend=compressor, CPUs=CPUs, iteration=iteration)
692
+ # Store compression frontiers in the result.
693
+ for c in compressionFrontiers:
694
+ result.allFrontiers[c.task] = c.topK(0) if c.task in needToSupervise else c  # needToSupervise holds tasks, not frontiers
695
+
696
+
697
+ result.grammars.append(grammar)
698
+ eprint("Grammar after iteration %d:" % (iteration + 1))
699
+ eprint(grammar)
700
+
701
+ return grammar
702
+
703
+
704
+
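A sketch of the corresponding sleep-G call (hyperparameter values are illustrative):

    grammar = consolidate(result, grammar,
                          topK=2, pseudoCounts=30.0, arity=3,
                          aic=1.0, structurePenalty=1.5,
                          compressor="ocaml", CPUs=4, iteration=0)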
705
+ def commandlineArguments(_=None,
706
+ iterations=None,
707
+ enumerationTimeout=None,
708
+ testEvery=1,
709
+ topK=1,
710
+ reuseRecognition=False,
711
+ CPUs=1,
712
+ solver='ocaml',
713
+ compressor="ocaml",
714
+ useRecognitionModel=True,
715
+ recognitionTimeout=None,
716
+ activation='relu',
717
+ helmholtzRatio=1.,
718
+ featureExtractor=None,
719
+ cuda=None,
720
+ maximumFrontier=None,
721
+ pseudoCounts=1.0, aic=1.0,
722
+ structurePenalty=0.001, a=0,
723
+ taskBatchSize=None, taskReranker="default",
724
+ extras=None,
725
+ storeTaskMetrics=False,
726
+ rewriteTaskMetrics=True):
727
+ if cuda is None:
728
+ cuda = torch.cuda.is_available()
729
+ print("CUDA is available?:", torch.cuda.is_available())
730
+ print("using cuda?:", cuda)
731
+ import argparse
732
+ parser = argparse.ArgumentParser(description="")
733
+ parser.add_argument("--resume",
734
+ help="Resumes EC algorithm from checkpoint. You can either pass in the path of a checkpoint, or you can pass in the iteration to resume from, in which case it will try to figure out the path.",
735
+ default=None,
736
+ type=str)
737
+ parser.add_argument("-i", "--iterations",
738
+ help="default: %d" % iterations,
739
+ default=iterations,
740
+ type=int)
741
+ parser.add_argument("-t", "--enumerationTimeout",
742
+ default=enumerationTimeout,
743
+ help="In seconds. default: %s" % enumerationTimeout,
744
+ type=int)
745
+ parser.add_argument("-R", "--recognitionTimeout",
746
+ default=recognitionTimeout,
747
+ help="In seconds. Amount of time to train the recognition model on each iteration. Defaults to enumeration timeout.",
748
+ type=int)
749
+ parser.add_argument("-RS", "--recognitionSteps",
750
+ default=None,
751
+ help="Number of gradient steps to train the recognition model. Can be specified instead of train time.",
752
+ type=int)
753
+ parser.add_argument(
754
+ "-k",
755
+ "--topK",
756
+ default=topK,
757
+ help="When training generative and discriminative models, we train them to fit the top K programs. Ideally we would train them to fit the entire frontier, but this is often intractable. default: %d" %
758
+ topK,
759
+ type=int)
760
+ parser.add_argument("-p", "--pseudoCounts",
761
+ default=pseudoCounts,
762
+ help="default: %f" % pseudoCounts,
763
+ type=float)
764
+ parser.add_argument("-b", "--aic",
765
+ default=aic,
766
+ help="default: %f" % aic,
767
+ type=float)
768
+ parser.add_argument("-l", "--structurePenalty",
769
+ default=structurePenalty,
770
+ help="default: %f" % structurePenalty,
771
+ type=float)
772
+ parser.add_argument("-a", "--arity",
773
+ default=a,
774
+ help="default: %d" % a,
775
+ type=int)
776
+ parser.add_argument("-c", "--CPUs",
777
+ default=CPUs,
778
+ help="default: %d" % CPUs,
779
+ type=int)
780
+ parser.add_argument("--no-cuda",
781
+ action="store_false",
782
+ dest="cuda",
783
+ help="""cuda will be used if available (which it %s),
784
+ unless this is set""" % ("IS" if cuda else "ISN'T"))
785
+ parser.add_argument("-m", "--maximumFrontier",
786
+ help="""Even though we enumerate --frontierSize
787
+ programs, we might want to only keep around the very
788
+ best for performance reasons. This is a cut off on the
789
+ maximum size of the frontier that is kept around.
790
+ Default: %s""" % maximumFrontier,
791
+ type=int)
792
+ parser.add_argument("--reuseRecognition",
793
+ help="""Should we initialize recognition model weights to be what they were at the previous DreamCoder iteration? Default: %s""" % reuseRecognition,
794
+ default=reuseRecognition,
795
+ action="store_true")
796
+ parser.add_argument("--recognition",
797
+ dest="useRecognitionModel",
798
+ action="store_true",
799
+ help="""Enable bottom-up neural recognition model.
800
+ Default: %s""" % useRecognitionModel)
801
+ parser.add_argument("--ensembleSize",
802
+ dest="ensembleSize",
803
+ default=1,
804
+ help="Number of recognition models to train and enumerate from at each iteration.",
805
+ type=int)
806
+ parser.add_argument("-g", "--no-recognition",
807
+ dest="useRecognitionModel",
808
+ action="store_false",
809
+ help="""Disable bottom-up neural recognition model.
810
+ Default: %s""" % (not useRecognitionModel))
811
+ parser.add_argument("-d", "--no-dsl",
812
+ dest="useDSL",
813
+ action="store_false",
814
+ help="""Disable DSL enumeration and updating.""")
815
+ parser.add_argument("--no-consolidation",
816
+ dest="noConsolidation",
817
+ action="store_true",
818
+ help="""Disable DSL updating.""")
819
+ parser.add_argument(
820
+ "--testingTimeout",
821
+ type=int,
822
+ dest="testingTimeout",
823
+ default=0,
824
+ help="Number of seconds we should spend evaluating on each held out testing task.")
825
+ parser.add_argument(
826
+ "--testEvery",
827
+ type=int,
828
+ dest="testEvery",
829
+ default=1,
830
+ help="Run heldout testing every X iterations."
831
+ )
832
+ parser.add_argument(
833
+ "--seed",
834
+ type=int,
835
+ default=0,
836
+ help="Random seed. Currently this only matters for random batching strategies.")
837
+ parser.add_argument(
838
+ "--activation",
839
+ choices=[
840
+ "relu",
841
+ "sigmoid",
842
+ "tanh"],
843
+ default=activation,
844
+ help="""Activation function for neural recognition model.
845
+ Default: %s""" %
846
+ activation)
847
+ parser.add_argument(
848
+ "--solver",
849
+ choices=[
850
+ "ocaml",
851
+ "pypy",
852
+ "python"],
853
+ default=solver,
854
+ help="""Solver for enumeration.
855
+ Default: %s""" %
856
+ solver)
857
+ parser.add_argument(
858
+ "-r",
859
+ "--Helmholtz",
860
+ dest="helmholtzRatio",
861
+ help="""When training recognition models, what fraction of the training data should be samples from the generative model? Default %f""" %
862
+ helmholtzRatio,
863
+ default=helmholtzRatio,
864
+ type=float)
865
+ parser.add_argument(
866
+ "--compressor",
867
+ default=compressor,
868
+ choices=["pypy","rust","vs","pypy_vs","ocaml","memorize"])
869
+ parser.add_argument(
870
+ "--matrixRank",
871
+ help="Maximum rank of bigram transition matrix for contextual recognition model. Defaults to full rank.",
872
+ default=None,
873
+ type=int)
874
+ parser.add_argument(
875
+ "--mask",
876
+ help="Unconditional bigram masking",
877
+ default=False, action="store_true")
878
+ parser.add_argument("--biasOptimal",
879
+ help="Enumerate dreams rather than sample them & bias-optimal recognition objective",
880
+ default=False, action="store_true")
881
+ parser.add_argument("--contextual",
882
+ help="bigram recognition model (default is unigram model)",
883
+ default=False, action="store_true")
884
+ parser.add_argument("--clear-recognition",
885
+ dest="clear-recognition",
886
+ help="Clears the recognition model from a checkpoint. Necessary for graphing results with recognition models, because pickle is kind of stupid sometimes.",
887
+ default=None,
888
+ type=str)
889
+ parser.add_argument("--primitive-graph",
890
+ dest="primitive-graph",
891
+ nargs='+',
892
+ help="Displays a dependency graph of the learned primitives",
893
+ default=None,
894
+ type=str)
895
+ parser.add_argument(
896
+ "--taskBatchSize",
897
+ dest="taskBatchSize",
898
+ help="Number of tasks to train on during wake. Defaults to all tasks if None.",
899
+ default=None,
900
+ type=int)
901
+ parser.add_argument(
902
+ "--taskReranker",
903
+ dest="taskReranker",
904
+ help="Reranking function used to order the tasks we train on during waking.",
905
+ choices=[
906
+ "default",
907
+ "random",
908
+ "randomShuffle",
909
+ "unsolved",
910
+ "unsolvedEntropy",
911
+ "unsolvedRandomEntropy",
912
+ "randomkNN",
913
+ "randomLowEntropykNN"],
914
+ default=taskReranker,
915
+ type=str)
916
+ parser.add_argument(
917
+ "--storeTaskMetrics",
918
+ dest="storeTaskMetrics",
919
+ default=True,
920
+ help="Whether to store task metrics directly in the ECResults.",
921
+ action="store_true"
922
+ )
923
+ parser.add_argument(
924
+ "--rewriteTaskMetrics",
925
+ dest="rewriteTaskMetrics",
926
+ help="Whether to rewrite a new task metrics dictionary at each iteration, rather than retaining the old.",
927
+ action="store_true"
928
+ )
929
+ parser.add_argument("--addTaskMetrics",
930
+ dest="addTaskMetrics",
931
+ help="Creates a checkpoint with task metrics and no recognition model for graphing.",
932
+ default=None,
933
+ nargs='+',
934
+ type=str)
935
+ parser.add_argument("--auxiliary",
936
+ action="store_true", default=False,
937
+ help="Add auxiliary classification loss to recognition network training",
938
+ dest="auxiliaryLoss")
939
+ parser.add_argument("--addFullTaskMetrics",
940
+ help="Only to be used in conjunction with --resume. Loads checkpoint, solves both testing and training tasks, stores frontiers, solve times, and task metrics, and then dies.",
941
+ default=False,
942
+ action="store_true")
943
+ parser.add_argument("--countParameters",
944
+ help="Load a checkpoint then report how many parameters are in the recognition model.",
945
+ default=None, type=str)
946
+ parser.set_defaults(useRecognitionModel=useRecognitionModel,
947
+ useDSL=True,
948
+ featureExtractor=featureExtractor,
949
+ maximumFrontier=maximumFrontier,
950
+ cuda=cuda)
951
+ if extras is not None:
952
+ extras(parser)
953
+ v = vars(parser.parse_args())
954
+ if v["clear-recognition"] is not None:
955
+ ECResult.clearRecognitionModel(v["clear-recognition"])
956
+ sys.exit(0)
957
+ else:
958
+ del v["clear-recognition"]
959
+
960
+ if v["primitive-graph"] is not None:
961
+
962
+ for n,pg in enumerate(v["primitive-graph"]):
963
+ with open(pg,'rb') as handle:
964
+ result = dill.load(handle)
965
+ graphPrimitives(result,f"figures/deepProgramLearning/{sys.argv[0]}{n}",view=True)
966
+ sys.exit(0)
967
+ else:
968
+ del v["primitive-graph"]
969
+
970
+ if v["addTaskMetrics"] is not None:
971
+ for path in v["addTaskMetrics"]:
972
+ with open(path,'rb') as handle:
973
+ result = dill.load(handle)
974
+ addTaskMetrics(result, path)
975
+ sys.exit(0)
976
+ else:
977
+ del v["addTaskMetrics"]
978
+
979
+ if v["useRecognitionModel"] and v["recognitionTimeout"] is None:
980
+ v["recognitionTimeout"] = v["enumerationTimeout"]
981
+
982
+ if v["countParameters"]:
983
+ with open(v["countParameters"], "rb") as handle:
984
+ result = dill.load(handle)
985
+ eprint("The recognition model has",
986
+ sum(p.numel() for p in result.recognitionModel.parameters() if p.requires_grad),
987
+ "trainable parameters and",
988
+ sum(p.numel() for p in result.recognitionModel.parameters() ),
989
+ "total parameters.\n",
990
+ "The feature extractor accounts for",
991
+ sum(p.numel() for p in result.recognitionModel.featureExtractor.parameters() ),
992
+ "of those parameters.\n",
993
+ "The grammar builder accounts for",
994
+ sum(p.numel() for p in result.recognitionModel.grammarBuilder.parameters() ),
995
+ "of those parameters.\n")
996
+ sys.exit(0)
997
+ del v["countParameters"]
998
+
999
+
1000
+ return v
1001
+
1002
+ def addTaskMetrics(result, path):
1003
+ """Adds a task metrics to ECResults that were pickled without them."""
1004
+ with torch.no_grad(): return addTaskMetrics_(result, path)
1005
+ def addTaskMetrics_(result, path):
1006
+ SUFFIX = '.pickle'
1007
+ assert path.endswith(SUFFIX)
1008
+
1009
+ tasks = result.taskSolutions.keys()
1010
+ everyTask = set(tasks)
1011
+ for t in result.recognitionTaskMetrics:
1012
+ if isinstance(t, Task) and t not in everyTask: everyTask.add(t)
1013
+
1014
+ eprint(f"Found {len(tasks)} training tasks.")
1015
+ eprint(f"Scrounged up {len(everyTask) - len(tasks)} testing tasks.")
1016
+ if not hasattr(result, "recognitionTaskMetrics") or result.recognitionTaskMetrics is None:
1017
+ result.recognitionTaskMetrics = {}
1018
+
1019
+ # If task has images, store them.
1020
+ if hasattr(list(tasks)[0], 'getImage'):
1021
+ images = {t: t.getImage(pretty=True) for t in tasks}
1022
+ updateTaskSummaryMetrics(result.recognitionTaskMetrics, images, 'taskImages')
1023
+
1024
+ if hasattr(list(tasks)[0], 'highresolution'):
1025
+ images = {t: t.highresolution for t in tasks}
1026
+ updateTaskSummaryMetrics(result.recognitionTaskMetrics, images, 'taskImages')
1027
+
1028
+ updateTaskSummaryMetrics(result.recognitionTaskMetrics, result.recognitionModel.auxiliaryPrimitiveEmbeddings(), 'auxiliaryPrimitiveEmbeddings')
1029
+ updateTaskSummaryMetrics(result.recognitionTaskMetrics, result.recognitionModel.taskAuxiliaryLossLayer(tasks), 'taskAuxiliaryLossLayer')
1030
+ updateTaskSummaryMetrics(result.recognitionTaskMetrics, result.recognitionModel.taskAuxiliaryLossLayer(everyTask), 'every_auxiliaryLossLayer')
1031
+
1032
+ updateTaskSummaryMetrics(result.recognitionTaskMetrics, result.recognitionModel.taskGrammarFeatureLogProductions(tasks), 'grammarFeatureLogProductions')
1033
+ updateTaskSummaryMetrics(result.recognitionTaskMetrics, result.recognitionModel.taskGrammarFeatureLogProductions(everyTask), 'every_grammarFeatureLogProductions')
1034
+ updateTaskSummaryMetrics(result.recognitionTaskMetrics, result.recognitionModel.taskGrammarLogProductions(tasks), 'contextualLogProductions')
1035
+ updateTaskSummaryMetrics(result.recognitionTaskMetrics, result.recognitionModel.taskGrammarLogProductions(everyTask), 'every_contextualLogProductions')
1036
+ updateTaskSummaryMetrics(result.recognitionTaskMetrics, result.recognitionModel.taskHiddenStates(tasks), 'hiddenState')
1037
+ updateTaskSummaryMetrics(result.recognitionTaskMetrics, result.recognitionModel.taskHiddenStates(everyTask), 'every_hiddenState')
1038
+ g = result.grammars[-2] # the final entry in result.grammars is a grammar that we have not used yet
1039
+ updateTaskSummaryMetrics(result.recognitionTaskMetrics, {f.task: f.expectedProductionUses(g)
1040
+ for f in result.taskSolutions.values()
1041
+ if len(f) > 0},
1042
+ 'expectedProductionUses')
1043
+ updateTaskSummaryMetrics(result.recognitionTaskMetrics, {f.task: f.expectedProductionUses(g)
1044
+ for t, metrics in result.recognitionTaskMetrics.items()
1045
+ if "frontier" in metrics
1046
+ for f in [metrics["frontier"]]
1047
+ if len(f) > 0},
1048
+ 'every_expectedProductionUses')
1049
+ if False:
1050
+ eprint(f"About to do an expensive Monte Carlo simulation w/ {len(tasks)} tasks")
1051
+ updateTaskSummaryMetrics(result.recognitionTaskMetrics,
1052
+ {task: result.recognitionModel.grammarOfTask(task).untorch().expectedUsesMonteCarlo(task.request, debug=False)
1053
+ for task in tasks },
1054
+ 'expectedProductionUsesMonteCarlo')
1055
+ try:
1056
+ updateTaskSummaryMetrics(result.recognitionTaskMetrics,
1057
+ result.recognitionModel.taskGrammarStartProductions(tasks),
1058
+ 'startProductions')
1059
+ except Exception: pass # can fail if we do not have a contextual model
1060
+
1061
+ #updateTaskSummaryMetrics(result.recognitionTaskMetrics, result.recognitionModel.taskGrammarLogProductions(tasks), 'task_no_parent_log_productions')
1062
+ #updateTaskSummaryMetrics(result.recognitionTaskMetrics, result.recognitionModel.taskGrammarEntropies(tasks), 'taskGrammarEntropies')
1063
+
1064
+ result.recognitionModel = None
1065
+
1066
+ clearedPath = path[:-len(SUFFIX)] + "_graph=True" + SUFFIX
1067
+ with open(clearedPath,'wb') as handle:
1068
+ dill.dump(result, handle)
1069
+ eprint(" [+] Cleared recognition model from:")
1070
+ eprint(" %s"%path)
1071
+ eprint(" and exported to:")
1072
+ eprint(" %s"%clearedPath)
1073
+ eprint(" Use this one for graphing.")
1074
+
dreamcoder/dreaming.py ADDED
@@ -0,0 +1,90 @@
1
+ import json
2
+ import os
3
+ import subprocess
4
+
5
+ from dreamcoder.domains.arithmetic.arithmeticPrimitives import k1, k0, addition, subtraction, multiplication
6
+ from dreamcoder.frontier import Frontier, FrontierEntry
7
+ from dreamcoder.grammar import Grammar
8
+ from dreamcoder.program import Program
9
+ from dreamcoder.task import Task
10
+ from dreamcoder.type import arrow, tint
11
+ from dreamcoder.utilities import tuplify, timing, eprint, get_root_dir, mean
12
+
13
+
14
+ def helmholtzEnumeration(g, request, inputs, timeout, _=None,
15
+ special=None, evaluationTimeout=None):
16
+ """Returns json (as text)"""
17
+ message = {"request": request.json(),
18
+ "timeout": timeout,
19
+ "DSL": g.json(),
20
+ "extras": inputs}
21
+ if evaluationTimeout: message["evaluationTimeout"] = evaluationTimeout
22
+ if special: message["special"] = special
23
+ message = json.dumps(message)
24
+ with open('/tmp/hm', 'w') as handle:
25
+ handle.write(message)
26
+ try:
27
+ binary = os.path.join(get_root_dir(), 'helmholtz')
28
+ process = subprocess.Popen(binary,
29
+ stdin=subprocess.PIPE,
30
+ stdout=subprocess.PIPE)
31
+ response, error = process.communicate(bytes(message, encoding="utf-8"))
32
+ except OSError as exc:
33
+ raise exc
34
+ return response
35
+
36
+
37
+ def backgroundHelmholtzEnumeration(tasks, g, timeout, _=None,
38
+ special=None, evaluationTimeout=None):
39
+ from pathos.multiprocessing import Pool
40
+ requests = list({t.request for t in tasks})
41
+ inputs = {r: list({tuplify(xs)
42
+ for t in tasks if t.request == r
43
+ for xs, y in t.examples})
44
+ for r in requests}
45
+ workers = Pool(len(requests))
46
+ promises = [workers.apply_async(helmholtzEnumeration,
47
+ args=(g, r, inputs[r], float(timeout)),
48
+ kwds={'special': special,
49
+ 'evaluationTimeout': evaluationTimeout})
50
+ for r in requests]
51
+
52
+ def get():
53
+ results = [p.get() for p in promises]
54
+ frontiers = []
55
+ with timing("(Helmholtz enumeration) Decoded json into frontiers"):
56
+ for request, result in zip(requests, results):
57
+ response = json.loads(result.decode("utf-8"))
58
+ for b, entry in enumerate(response):
59
+ frontiers.append(Frontier([FrontierEntry(program=Program.parse(p),
60
+ logPrior=entry["ll"],
61
+ logLikelihood=0.)
62
+ for p in entry["programs"]],
63
+ task=Task(str(b),
64
+ request,
65
+ [])))
66
+ eprint("Total number of Helmholtz frontiers:", len(frontiers))
67
+ return frontiers
68
+
69
+ return get
70
+
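Because backgroundHelmholtzEnumeration returns a thunk rather than the frontiers themselves, dreaming can overlap with other work; a usage sketch (the timeout value is illustrative):

    getFrontiers = backgroundHelmholtzEnumeration(tasks, g, timeout=60)
    # ... e.g. train the recognition model concurrently ...
    frontiers = getFrontiers()  # blocks on the worker pool, then decodes the JSON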
71
+
72
+ if __name__ == "__main__":
73
+ from dreamcoder.recognition import RecognitionModel, DummyFeatureExtractor
74
+ g = Grammar.uniform([k1, k0, addition, subtraction, multiplication])
75
+ frontiers = helmholtzEnumeration(g,
76
+ arrow(tint, tint),
77
+ [[0], [1], [2]],
78
+ 10.)
79
+ eprint("average frontier size", mean(len(f.entries) for f in frontiers))
80
+ f = DummyFeatureExtractor([])
81
+ r = RecognitionModel(f, g, hidden=[], contextual=True)
82
+ r.trainBiasOptimal(frontiers, frontiers, steps=70)
83
+ g = r.grammarOfTask(frontiers[0].task).untorch()
84
+ frontiers = helmholtzEnumeration(g,
85
+ arrow(tint, tint),
86
+ [[0], [1], [2]],
87
+ 10.)
88
+ for f in frontiers:
89
+ eprint(f.summarizeFull())
90
+ eprint("average frontier size", mean(len(f.entries) for f in frontiers))
dreamcoder/ec.py ADDED
@@ -0,0 +1,3 @@
1
+ print("DEPRECATION NOTICE: this module (dreamcoder.ec) will be deleted soon, "
2
+ "please update your code to import from dreamcoder.dreamcoder instead")
3
+ from dreamcoder.dreamcoder import * # noqa
dreamcoder/enumeration.py ADDED
@@ -0,0 +1,469 @@
1
+ from dreamcoder.likelihoodModel import AllOrNothingLikelihoodModel
2
+ from dreamcoder.grammar import *
3
+ from dreamcoder.utilities import get_root_dir
4
+
5
+ import os
6
+ import traceback
7
+ import subprocess
8
+
9
+
10
+ def multicoreEnumeration(g, tasks, _=None,
11
+ enumerationTimeout=None,
12
+ solver='ocaml',
13
+ CPUs=1,
14
+ maximumFrontier=None,
15
+ verbose=True,
16
+ evaluationTimeout=None,
17
+ testing=False):
18
+ '''g: Either a Grammar, or a map from task to grammar.
19
+ Returns (list-of-frontiers, map-from-task-to-search-time)'''
20
+
21
+ # We don't use actual threads but instead use the multiprocessing
22
+ # library. This is because we need to be able to kill workers.
23
+ #from multiprocess import Process, Queue
24
+
25
+ from multiprocessing import Queue
26
+
27
+ # everything that gets sent between processes will be dilled
28
+ import dill
29
+
30
+ solvers = {"ocaml": solveForTask_ocaml,
31
+ "pypy": solveForTask_pypy,
32
+ "python": solveForTask_python}
33
+ assert solver in solvers, "You must specify a valid solver. options are ocaml, pypy, or python."
34
+
35
+ likelihoodModel = None
36
+ if solver == 'pypy' or solver == 'python':
37
+ # Use an all or nothing likelihood model.
38
+ likelihoodModel = AllOrNothingLikelihoodModel(timeout=evaluationTimeout)
39
+
40
+ solver = solvers[solver]
41
+
42
+ if not isinstance(g, dict):
43
+ g = {t: g for t in tasks}
44
+ task2grammar = g
45
+
46
+ # If we are not evaluating on held out testing tasks:
47
+ # Bin the tasks by request type and grammar
48
+ # If these are the same then we can enumerate for multiple tasks simultaneously
49
+ # If we are evaluating testing tasks:
50
+ # Make sure that each job corresponds to exactly one task
51
+ jobs = {}
52
+ for i, t in enumerate(tasks):
53
+ if testing:
54
+ k = (task2grammar[t], t.request, i)
55
+ else:
56
+ k = (task2grammar[t], t.request)
57
+ jobs[k] = jobs.get(k, []) + [t]
58
+
59
+ disableParallelism = len(jobs) == 1
60
+ parallelCallback = launchParallelProcess if not disableParallelism else lambda f, * \
61
+ a, **k: f(*a, **k)
62
+ if disableParallelism:
63
+ eprint("Disabling parallelism on the Python side because we only have one job.")
64
+ eprint("If you are using ocaml, there could still be parallelism.")
65
+
66
+ # Map from task to the shortest time to find a program solving it
67
+ bestSearchTime = {t: None for t in task2grammar}
68
+
69
+ lowerBounds = {k: 0. for k in jobs}
70
+
71
+ frontiers = {t: Frontier([], task=t) for t in task2grammar}
72
+
73
+ # For each job we keep track of how long we have been working on it
74
+ stopwatches = {t: Stopwatch() for t in jobs}
75
+
76
+ # Map from task to how many programs we enumerated for that task
77
+ taskToNumberOfPrograms = {t: 0 for t in tasks }
78
+
79
+ def numberOfHits(f):
80
+ return sum(e.logLikelihood > -0.01 for e in f)
81
+
82
+ def budgetIncrement(lb):
83
+ if True:
84
+ return 1.5
85
+ # Very heuristic - not sure what to do here
86
+ if lb < 24.:
87
+ return 1.
88
+ elif lb < 27.:
89
+ return 0.5
90
+ else:
91
+ return 0.25
92
+
93
+ def maximumFrontiers(j):
94
+ tasks = jobs[j]
95
+ return {t: maximumFrontier - numberOfHits(frontiers[t]) for t in tasks}
96
+
97
+ def allocateCPUs(n, tasks):
98
+ allocation = {t: 0 for t in tasks}
99
+ while n > 0:
100
+ for t in tasks:
101
+ # During testing we use exactly one CPU per task
102
+ if testing and allocation[t] > 0:
103
+ return allocation
104
+ allocation[t] += 1
105
+ n -= 1
106
+ if n == 0:
107
+ break
108
+ return allocation
109
+
110
+ def refreshJobs():
111
+ for k in list(jobs.keys()):
112
+ v = [t for t in jobs[k]
113
+ if numberOfHits(frontiers[t]) < maximumFrontier
114
+ and stopwatches[k].elapsed <= enumerationTimeout]
115
+ if v:
116
+ jobs[k] = v
117
+ else:
118
+ del jobs[k]
119
+
120
+ # Workers put their messages in here
121
+ q = Queue()
122
+
123
+ # How many CPUs are we using?
124
+ activeCPUs = 0
125
+
126
+ # How many CPUs was each job allocated?
127
+ id2CPUs = {}
128
+ # What job was each ID working on?
129
+ id2job = {}
130
+ nextID = 0
131
+
132
+ while True:
133
+ refreshJobs()
134
+ # Don't launch a job that we are already working on
135
+ # We run the stopwatch whenever the job is being worked on
136
+ # freeJobs are things that we are not working on but could be
137
+ freeJobs = [j for j in jobs if not stopwatches[j].running
138
+ and stopwatches[j].elapsed < enumerationTimeout - 0.5]
139
+ if freeJobs and activeCPUs < CPUs:
140
+ # Allocate a CPU to each of the jobs that we have made the least
141
+ # progress on
142
+ freeJobs.sort(key=lambda j: lowerBounds[j])
143
+ # Launch some more jobs until all of the CPUs are being used
144
+ availableCPUs = CPUs - activeCPUs
145
+ allocation = allocateCPUs(availableCPUs, freeJobs)
146
+ for j in freeJobs:
147
+ if allocation[j] == 0:
148
+ continue
149
+ g, request = j[:2]
150
+ bi = budgetIncrement(lowerBounds[j])
151
+ thisTimeout = enumerationTimeout - stopwatches[j].elapsed
152
+ eprint("(python) Launching %s (%d tasks) w/ %d CPUs. %f <= MDL < %f. Timeout %f." %
153
+ (request, len(jobs[j]), allocation[j], lowerBounds[j], lowerBounds[j] + bi, thisTimeout))
154
+ stopwatches[j].start()
155
+ parallelCallback(wrapInThread(solver),
156
+ q=q, g=g, ID=nextID,
157
+ elapsedTime=stopwatches[j].elapsed,
158
+ CPUs=allocation[j],
159
+ tasks=jobs[j],
160
+ lowerBound=lowerBounds[j],
161
+ upperBound=lowerBounds[j] + bi,
162
+ budgetIncrement=bi,
163
+ timeout=thisTimeout,
164
+ evaluationTimeout=evaluationTimeout,
165
+ maximumFrontiers=maximumFrontiers(j),
166
+ testing=testing,
167
+ likelihoodModel=likelihoodModel)
168
+ id2CPUs[nextID] = allocation[j]
169
+ id2job[nextID] = j
170
+ nextID += 1
171
+
172
+ activeCPUs += allocation[j]
173
+ lowerBounds[j] += bi
174
+
175
+ # If nothing is running, and we just tried to launch jobs,
176
+ # then that means we are finished
177
+ if all(not s.running for s in stopwatches.values()):
178
+ break
179
+
180
+ # Wait to get a response
181
+ message = Bunch(dill.loads(q.get()))
182
+
183
+ if message.result == "failure":
184
+ eprint("PANIC! Exception in child worker:", message.exception)
185
+ eprint(message.stacktrace)
186
+ assert False
187
+ elif message.result == "success":
188
+ # Mark the CPUs is no longer being used and pause the stopwatch
189
+ activeCPUs -= id2CPUs[message.ID]
190
+ stopwatches[id2job[message.ID]].stop()
191
+
192
+ newFrontiers, searchTimes, pc = message.value
193
+ for t, f in newFrontiers.items():
194
+ oldBest = None if len(
195
+ frontiers[t]) == 0 else frontiers[t].bestPosterior
196
+ frontiers[t] = frontiers[t].combine(f)
197
+ newBest = None if len(
198
+ frontiers[t]) == 0 else frontiers[t].bestPosterior
199
+
200
+ taskToNumberOfPrograms[t] += pc
201
+
202
+ dt = searchTimes[t]
203
+ if dt is not None:
204
+ if bestSearchTime[t] is None:
205
+ bestSearchTime[t] = dt
206
+ else:
207
+ # newBest & oldBest should both be defined
208
+ assert oldBest is not None
209
+ assert newBest is not None
210
+ newScore = newBest.logPrior + newBest.logLikelihood
211
+ oldScore = oldBest.logPrior + oldBest.logLikelihood
212
+
213
+ if newScore > oldScore:
214
+ bestSearchTime[t] = dt
215
+ elif newScore == oldScore:
216
+ bestSearchTime[t] = min(bestSearchTime[t], dt)
217
+ else:
218
+ eprint("Unknown message result:", message.result)
219
+ assert False
220
+
221
+ eprint("We enumerated this many programs, for each task:\n\t",
222
+ list(taskToNumberOfPrograms.values()))
223
+
224
+ return [frontiers[t] for t in tasks], bestSearchTime
225
+
226
+ def wrapInThread(f):
227
+ """
228
+ Returns a function that is designed to be run in a thread/threadlike process.
229
+ Result will be either put into the q
230
+ """
231
+ import dill
232
+
233
+ def _f(*a, **k):
234
+ q = k.pop("q")
235
+ ID = k.pop("ID")
236
+
237
+ try:
238
+ r = f(*a, **k)
239
+ q.put(dill.dumps({"result": "success",
240
+ "ID": ID,
241
+ "value": r}))
242
+ except Exception as e:
243
+ q.put(dill.dumps({"result": "failure",
244
+ "exception": e,
245
+ "stacktrace": traceback.format_exc(),
246
+ "ID": ID}))
247
+ return
248
+ return _f
249
+
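A sketch of how wrapInThread is meant to be driven; Process stands in for the launchParallelProcess helper used above, and the fork start method is assumed so the closure need not be pickled:

    from multiprocessing import Process, Queue
    import dill

    q = Queue()
    def double(x): return x * 2          # hypothetical worker function
    Process(target=wrapInThread(double), args=(21,), kwargs={"q": q, "ID": 0}).start()
    message = dill.loads(q.get())        # {"result": "success", "ID": 0, "value": 42}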
250
+
251
+ def solveForTask_ocaml(_=None,
252
+ elapsedTime=0.,
253
+ CPUs=1,
254
+ g=None, tasks=None,
255
+ lowerBound=None, upperBound=None, budgetIncrement=None,
256
+ timeout=None,
257
+ testing=None, # FIXME: unused
258
+ likelihoodModel=None,
259
+ evaluationTimeout=None, maximumFrontiers=None):
260
+
261
+ import json
262
+
263
+ def taskMessage(t):
264
+ m = {
265
+ "examples": [{"inputs": list(xs), "output": y} for xs, y in t.examples],
266
+ "name": t.name,
267
+ "request": t.request.json(),
268
+ "maximumFrontier": maximumFrontiers[t]}
269
+ if hasattr(t, "specialTask"):
270
+ special, extra = t.specialTask
271
+ m["specialTask"] = special
272
+ m["extras"] = extra
273
+ return m
274
+
275
+
276
+ message = {"DSL": g.json(),
277
+ "tasks": [taskMessage(t)
278
+ for t in tasks],
279
+
280
+ "programTimeout": evaluationTimeout,
281
+ "nc": CPUs,
282
+ "timeout": timeout,
283
+ "lowerBound": lowerBound,
284
+ "upperBound": upperBound,
285
+ "budgetIncrement": budgetIncrement,
286
+ "verbose": False,
287
+ "shatter": 5 if len(tasks) == 1 and "turtle" in str(tasks[0].request) else 10}
288
+
289
+ if hasattr(tasks[0], 'maxParameters') and tasks[0].maxParameters is not None:
290
+ message["maxParameters"] = tasks[0].maxParameters
291
+
292
+ message = json.dumps(message)
293
+ # uncomment this if you want to save the messages being sent to the solver
294
+
295
+
296
+ try:
297
+ solver_file = os.path.join(get_root_dir(), 'solver')
298
+ process = subprocess.Popen(solver_file,
299
+ stdin=subprocess.PIPE,
300
+ stdout=subprocess.PIPE)
301
+ response, error = process.communicate(bytes(message, encoding="utf-8"))
302
+ response = json.loads(response.decode("utf-8"))
303
+ except OSError as exc:
304
+ raise exc
305
+
306
+ except:
307
+ print("response:", response)
308
+ print("error:", error)
309
+ with open("message", "w") as f:
310
+ f.write(message)
311
+ print("message,", message)
312
+ assert False, "MAX RAISE"
313
+
314
+
315
+ pc = response.get("number_enumerated",0) # TODO
316
+ frontiers = {}
317
+ searchTimes = {}
318
+ for t in tasks:
319
+ solutions = response[t.name]
320
+ frontier = Frontier([FrontierEntry(program=p,
321
+ logLikelihood=e["logLikelihood"],
322
+ logPrior=g.logLikelihood(t.request, p))
323
+ for e in solutions
324
+ for p in [Program.parse(e["program"])]],
325
+ task=t)
326
+ frontiers[t] = frontier
327
+ if frontier.empty:
328
+ searchTimes[t] = None
329
+ # This is subtle:
330
+ # The search time we report is actually not be minimum time to find any solution
331
+ # Rather it is the time to find the MAP solution
332
+ # This is important for regression problems,
333
+ # where we might find something with a good prior but bad likelihood early on,
334
+ # and only later discovered the good high likelihood program
335
+ else:
336
+ searchTimes[t] = min(
337
+ (e["logLikelihood"] + e["logPrior"],
338
+ e["time"]) for e in solutions)[1] + elapsedTime
339
+
340
+ return frontiers, searchTimes, pc
341
+
342
+ def solveForTask_pypy(_=None,
343
+ elapsedTime=0.,
344
+ g=None, task=None,
345
+ lowerBound=None, upperBound=None, budgetIncrement=None,
346
+ timeout=None,
347
+ likelihoodModel=None,
348
+ evaluationTimeout=None, maximumFrontier=None, testing=False):
349
+ return callCompiled(enumerateForTasks,
350
+ g, tasks, likelihoodModel,
351
+ timeout=timeout,
352
+ testing=testing,
353
+ elapsedTime=elapsedTime,
354
+ evaluationTimeout=evaluationTimeout,
355
+ maximumFrontiers=maximumFrontiers,
356
+ budgetIncrement=budgetIncrement,
357
+ lowerBound=lowerBound, upperBound=upperBound)
358
+
359
+ def solveForTask_python(_=None,
360
+ elapsedTime=0.,
361
+ g=None, tasks=None,
362
+ lowerBound=None, upperBound=None, budgetIncrement=None,
363
+ timeout=None,
364
+ CPUs=1,
365
+ likelihoodModel=None,
366
+ evaluationTimeout=None, maximumFrontiers=None, testing=False):
367
+ return enumerateForTasks(g, tasks, likelihoodModel,
368
+ timeout=timeout,
369
+ testing=testing,
370
+ elapsedTime=elapsedTime,
371
+ evaluationTimeout=evaluationTimeout,
372
+ maximumFrontiers=maximumFrontiers,
373
+ budgetIncrement=budgetIncrement,
374
+ lowerBound=lowerBound, upperBound=upperBound)
375
+
376
+
377
+ class EnumerationTimeout(Exception):
378
+ pass
379
+
380
+ def enumerateForTasks(g, tasks, likelihoodModel, _=None,
381
+ verbose=False,
382
+ timeout=None,
383
+ elapsedTime=0.,
384
+ CPUs=1,
385
+ testing=False, #unused
386
+ evaluationTimeout=None,
387
+ lowerBound=0.,
388
+ upperBound=100.,
389
+ budgetIncrement=1.0, maximumFrontiers=None):
390
+ assert timeout is not None, \
391
+ "enumerateForTasks: You must provide a timeout."
392
+
393
+ from time import time
394
+
395
+ request = tasks[0].request
396
+ assert all(t.request == request for t in tasks), \
397
+ "enumerateForTasks: Expected tasks to all have the same type"
398
+
399
+ maximumFrontiers = [maximumFrontiers[t] for t in tasks]
400
+ # store all of the hits in a priority queue
401
+ # we will never maintain maximumFrontier best solutions
402
+ hits = [PQ() for _ in tasks]
403
+
404
+ starting = time()
405
+ previousBudget = lowerBound
406
+ budget = lowerBound + budgetIncrement
407
+ try:
408
+ totalNumberOfPrograms = 0
409
+ while time() < starting + timeout and \
410
+ any(len(h) < mf for h, mf in zip(hits, maximumFrontiers)) and \
411
+ budget <= upperBound:
412
+ numberOfPrograms = 0
413
+
414
+ for prior, _, p in g.enumeration(Context.EMPTY, [], request,
415
+ maximumDepth=99,
416
+ upperBound=budget,
417
+ lowerBound=previousBudget):
418
+ descriptionLength = -prior
419
+ # Shouldn't see it on this iteration
420
+ assert descriptionLength <= budget
421
+ # Should already have seen it
422
+ assert descriptionLength > previousBudget
423
+
424
+ numberOfPrograms += 1
425
+ totalNumberOfPrograms += 1
426
+
427
+ for n in range(len(tasks)):
428
+ task = tasks[n]
429
+
430
+ #Warning:changed to max's new likelihood model situation
431
+ #likelihood = task.logLikelihood(p, evaluationTimeout)
432
+ #if invalid(likelihood):
433
+ #continue
434
+ success, likelihood = likelihoodModel.score(p, task)
435
+ if not success:
436
+ continue
437
+
438
+ dt = time() - starting + elapsedTime
439
+ priority = -(likelihood + prior)
440
+ hits[n].push(priority,
441
+ (dt, FrontierEntry(program=p,
442
+ logLikelihood=likelihood,
443
+ logPrior=prior)))
444
+ if len(hits[n]) > maximumFrontiers[n]:
445
+ hits[n].popMaximum()
446
+
447
+ if timeout is not None and time() - starting > timeout:
448
+ raise EnumerationTimeout
449
+
450
+ previousBudget = budget
451
+ budget += budgetIncrement
452
+
453
+ if budget > upperBound:
454
+ break
455
+ except EnumerationTimeout:
456
+ pass
457
+ frontiers = {tasks[n]: Frontier([e for _, e in hits[n]],
458
+ task=tasks[n])
459
+ for n in range(len(tasks))}
460
+ searchTimes = {
461
+ tasks[n]: None if len(hits[n]) == 0 else \
462
+ min(t for t,_ in hits[n]) for n in range(len(tasks))}
463
+
464
+ return frontiers, searchTimes, totalNumberOfPrograms
465
+
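The loop above is an iterative-deepening search over description length; a sketch of the budget schedule (numbers illustrative):

    lb, inc = 0.0, 1.5
    windows = [(lb + i * inc, lb + (i + 1) * inc) for i in range(3)]
    # -> [(0.0, 1.5), (1.5, 3.0), (3.0, 4.5)]: each pass enumerates only programs
    # whose description length (-logPrior) falls inside the new window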
466
+
467
+
468
+
469
+
dreamcoder/fragmentGrammar.py ADDED
@@ -0,0 +1,430 @@
1
+ from dreamcoder.fragmentUtilities import *
2
+ from dreamcoder.grammar import *
3
+ from dreamcoder.program import *
4
+
5
+ from itertools import chain
6
+ from time import time
7
+
8
+
9
+ class FragmentGrammar(object):
10
+ def __init__(self, logVariable, productions):
11
+ self.logVariable = logVariable
12
+ self.productions = productions
13
+ self.likelihoodCache = {}
14
+
15
+ def clearCache(self):
16
+ self.likelihoodCache = {}
17
+
18
+ def __repr__(self):
19
+ return "FragmentGrammar(logVariable={self.logVariable}, productions={self.productions}".format(
20
+ self=self)
21
+
22
+ def __str__(self):
23
+ def productionKey(production):
24
+ (l, t, p) = production
25
+ return not isinstance(p, Primitive), -l
26
+ return "\n".join(["%f\tt0\t$_" % self.logVariable] + ["%f\t%s\t%s" % (l, t, p)
27
+ for l, t, p in sorted(self.productions, key=productionKey)])
28
+
29
+ def buildCandidates(self, context, environment, request):
30
+ candidates = []
31
+ variableCandidates = []
32
+ for l, t, p in self.productions:
33
+ try:
34
+ newContext, t = t.instantiate(context)
35
+ newContext = newContext.unify(t.returns(), request)
36
+ candidates.append((l, newContext,
37
+ t.apply(newContext),
38
+ p))
39
+ except UnificationFailure:
40
+ continue
41
+ for j, t in enumerate(environment):
42
+ try:
43
+ newContext = context.unify(t.returns(), request)
44
+ variableCandidates.append((newContext,
45
+ t.apply(newContext),
46
+ Index(j)))
47
+ except UnificationFailure:
48
+ continue
49
+ if variableCandidates:
50
+ z = math.log(len(variableCandidates))
51
+ for newContext, newType, index in variableCandidates:
52
+ candidates.append(
53
+ (self.logVariable - z, newContext, newType, index))
54
+
55
+ z = lse([candidate[0] for candidate in candidates])
56
+ return [(l - z, c, t, p) for l, c, t, p in candidates]
57
+
58
+ def logLikelihood(self, request, expression):
59
+ _, l, _ = self._logLikelihood(Context.EMPTY, [], request, expression)
60
+ if invalid(l):
61
+ f = 'failures/likelihoodFailure%s.pickle' % (time() + getPID())
62
+ eprint("PANIC: Invalid log likelihood. expression:",
63
+ expression, "tp:", request, "Exported to:", f)
64
+ with open(f, 'wb') as handle:
65
+ pickle.dump((self, request, expression), handle)
66
+ assert False
67
+ return l
68
+
69
+ def closedUses(self, request, expression):
70
+ _, l, u = self._logLikelihood(Context.EMPTY, [], request, expression)
71
+ return l, u
72
+
73
+ def _logLikelihood(self, context, environment, request, expression):
74
+ '''returns (context, log likelihood, uses)'''
75
+
76
+ # We can cache likelihood calculations whenever they don't involve type inference
77
+ # This is because they are guaranteed not to modify the context.
78
+ polymorphic = request.isPolymorphic or any(
79
+ v.isPolymorphic for v in environment)
80
+ # For some reason polymorphic caching slows it down
81
+ shouldDoCaching = not polymorphic
82
+
83
+ # Caching
84
+ if shouldDoCaching:
85
+ if polymorphic:
86
+ inTypes = canonicalTypes(
87
+ [request.apply(context)] + [v.apply(context) for v in environment])
88
+ else:
89
+ inTypes = canonicalTypes([request] + environment)
90
+ cacheKey = (tuple(inTypes), expression)
91
+ if cacheKey in self.likelihoodCache:
92
+ outTypes, l, u = self.likelihoodCache[cacheKey]
93
+ context, instantiatedTypes = instantiateTypes(
94
+ context, outTypes)
95
+ outRequest = instantiatedTypes[0]
96
+ outEnvironment = instantiatedTypes[1:]
97
+ # eprint("request:", request.apply(context), "environment:",
98
+ # [ v.apply(context) for v in environment ])
99
+ # eprint("will be unified with: out request:",outRequest,"out environment",outEnvironment)
100
+ if polymorphic:
101
+ context = context.unify(request, outRequest)
102
+ for v, vp in zip(environment, outEnvironment):
103
+ context = context.unify(v, vp)
104
+ return context, l, u
105
+
106
+ if request.isArrow():
107
+ if not isinstance(expression, Abstraction):
108
+ return (context, NEGATIVEINFINITY, Uses.empty)
109
+ return self._logLikelihood(context,
110
+ [request.arguments[0]] + environment,
111
+ request.arguments[1],
112
+ expression.body)
113
+
114
+ # Not a function type
115
+
116
+ # Construct and normalize the candidate productions
117
+ candidates = self.buildCandidates(context, environment, request)
118
+
119
+ # Consider each way of breaking the expression up into a
120
+ # function and arguments
121
+ totalLikelihood = NEGATIVEINFINITY
122
+ weightedUses = []
123
+
124
+ possibleVariables = float(int(any(isinstance(candidate, Index)
125
+ for _, _, _, candidate in candidates)))
126
+ possibleUses = {candidate: 1. for _, _, _, candidate in candidates
127
+ if not isinstance(candidate, Index)}
128
+
129
+ for f, xs in expression.applicationParses():
130
+ for candidateLikelihood, newContext, tp, production in candidates:
131
+ variableBindings = {}
132
+ # This is a variable in the environment
133
+ if production.isIndex:
134
+ if production != f:
135
+ continue
136
+ else:
137
+ try:
138
+ newContext, fragmentType, variableBindings = \
139
+ Matcher.match(newContext, production, f, len(xs))
140
+ # This is necessary because the types of the variable
141
+ # bindings and holes need to match up w/ request
142
+ fragmentTypeTemplate = request
143
+ for _ in xs:
144
+ newContext, newVariable = newContext.makeVariable()
145
+ fragmentTypeTemplate = arrow(
146
+ newVariable, fragmentTypeTemplate)
147
+ newContext = newContext.unify(
148
+ fragmentType, fragmentTypeTemplate)
149
+ # update the unified type
150
+ tp = fragmentType.apply(newContext)
151
+ except MatchFailure:
152
+ continue
153
+
154
+ argumentTypes = tp.functionArguments()
155
+ if len(xs) != len(argumentTypes):
156
+ # I think that this is some kind of bug. But I can't figure it out right now.
157
+ # As a hack, count this as though it were a failure
158
+ continue
159
+ #raise GrammarFailure('len(xs) != len(argumentTypes): tp={}, xs={}'.format(tp, xs))
160
+
161
+ thisLikelihood = candidateLikelihood
162
+ if isinstance(production, Index):
163
+ theseUses = Uses(possibleVariables=possibleVariables,
164
+ actualVariables=1.,
165
+ possibleUses=possibleUses.copy(),
166
+ actualUses={})
167
+ else:
168
+ theseUses = Uses(possibleVariables=possibleVariables,
169
+ actualVariables=0.,
170
+ possibleUses=possibleUses.copy(),
171
+ actualUses={production: 1.})
172
+
173
+ # Accumulate likelihood from free variables and holes and
174
+ # arguments
175
+ for freeType, freeExpression in chain(
176
+ variableBindings.values(), zip(argumentTypes, xs)):
177
+ freeType = freeType.apply(newContext)
178
+ newContext, expressionLikelihood, newUses = self._logLikelihood(
179
+ newContext, environment, freeType, freeExpression)
180
+ if expressionLikelihood is NEGATIVEINFINITY:
181
+ thisLikelihood = NEGATIVEINFINITY
182
+ break
183
+
184
+ thisLikelihood += expressionLikelihood
185
+ theseUses += newUses
186
+
187
+ if thisLikelihood is NEGATIVEINFINITY:
188
+ continue
189
+
190
+ weightedUses.append((thisLikelihood, theseUses))
191
+ totalLikelihood = lse(totalLikelihood, thisLikelihood)
192
+
193
+ # Any of these new context objects should be equally good
194
+ context = newContext
195
+
196
+ if totalLikelihood is NEGATIVEINFINITY:
197
+ return context, totalLikelihood, Uses.empty
198
+ assert weightedUses != []
199
+
200
+ allUses = Uses.join(totalLikelihood, *weightedUses)
201
+
202
+ # memoize result
203
+ if shouldDoCaching:
204
+ outTypes = [request.apply(context)] + \
205
+ [v.apply(context) for v in environment]
206
+ outTypes = canonicalTypes(outTypes)
207
+ self.likelihoodCache[cacheKey] = (
208
+ outTypes, totalLikelihood, allUses)
209
+
210
+ return context, totalLikelihood, allUses
211
+
212
+ def expectedUses(self, frontiers):
213
+ if len(list(frontiers)) == 0:
214
+ return Uses()
215
+ likelihoods = [[(l + entry.logLikelihood, u)
216
+ for entry in frontier
217
+ for l, u in [self.closedUses(frontier.task.request, entry.program)]]
218
+ for frontier in frontiers]
219
+ zs = (lse([l for l, _ in ls]) for ls in likelihoods)
220
+ return sum(math.exp(l - z) * u
221
+ for z, frontier in zip(zs, likelihoods)
222
+ for l, u in frontier)
223
+
224
+ def insideOutside(self, frontiers, pseudoCounts):
225
+ uses = self.expectedUses(frontiers)
226
+ return FragmentGrammar(log(uses.actualVariables +
227
+ pseudoCounts) -
228
+ log(max(uses.possibleVariables, 1.)), [(log(uses.actualUses.get(p, 0.) +
229
+ pseudoCounts) -
230
+ log(uses.possibleUses.get(p, 0.) +
231
+ pseudoCounts), t, p) for _, t, p in self.productions])
232
+
233
+ def jointFrontiersLikelihood(self, frontiers):
234
+ return sum(lse([entry.logLikelihood + self.logLikelihood(frontier.task.request, entry.program)
235
+ for entry in frontier])
236
+ for frontier in frontiers)
237
+
238
+ def jointFrontiersMDL(self, frontiers, CPUs=1):
239
+ return sum(
240
+ parallelMap(
241
+ CPUs,
242
+ lambda frontier: max(
243
+ entry.logLikelihood +
244
+ self.logLikelihood(
245
+ frontier.task.request,
246
+ entry.program) for entry in frontier),
247
+ frontiers))
248
+
249
+ def __len__(self): return len(self.productions)
250
+
251
+ @staticmethod
252
+     def fromGrammar(g):
+         return FragmentGrammar(g.logVariable, g.productions)
+
+     def toGrammar(self):
+         return Grammar(self.logVariable, [(l, q.infer(), q)
+                                           for l, t, p in self.productions
+                                           for q in [defragment(p)]])
+
+     @property
+     def primitives(self): return [p for _, _, p in self.productions]
+
+     @staticmethod
+     def uniform(productions):
+         return FragmentGrammar(0., [(0., p.infer(), p) for p in productions])
+
+     def normalize(self):
+         z = lse([l for l, t, p in self.productions] + [self.logVariable])
+         return FragmentGrammar(self.logVariable - z,
+                                [(l - z, t, p) for l, t, p in self.productions])
+
+     def makeUniform(self):
+         return FragmentGrammar(0., [(0., p.infer(), p)
+                                     for _, _, p in self.productions])
+
+     def rescoreFrontier(self, frontier):
+         return Frontier([FrontierEntry(e.program,
+                                        logPrior=self.logLikelihood(frontier.task.request, e.program),
+                                        logLikelihood=e.logLikelihood)
+                          for e in frontier],
+                         frontier.task)
+
+     @staticmethod
+     def induceFromFrontiers(
+             g0,
+             frontiers,
+             _=None,
+             topK=1,
+             topk_use_only_likelihood=False,
+             pseudoCounts=1.0,
+             aic=1.0,
+             structurePenalty=0.001,
+             a=0,
+             CPUs=1):
+         _ = topk_use_only_likelihood  # not used in the Python compressor
+         originalFrontiers = frontiers
+         frontiers = [frontier for frontier in frontiers if not frontier.empty]
+         eprint("Inducing a grammar from", len(frontiers), "frontiers")
+
+         bestGrammar = FragmentGrammar.fromGrammar(g0)
+         oldJoint = bestGrammar.jointFrontiersMDL(frontiers, CPUs=1)
+
+         # "restricted frontiers" only contain the top K entries according to the best grammar
+         def restrictFrontiers():
+             return parallelMap(
+                 CPUs,
+                 lambda f: bestGrammar.rescoreFrontier(f).topK(topK),
+                 frontiers)
+         restrictedFrontiers = []
+
+         def grammarScore(g):
+             g = g.makeUniform().insideOutside(restrictedFrontiers, pseudoCounts)
+             likelihood = g.jointFrontiersMDL(restrictedFrontiers)
+             structure = sum(primitiveSize(p) for p in g.primitives)
+             score = likelihood - aic * len(g) - structurePenalty * structure
+             g.clearCache()
+             if invalid(score):
+                 # FIXME: This should never occur, but it does anyway
+                 score = float('-inf')
+             return score, g
+
+         if aic != POSITIVEINFINITY:
+             restrictedFrontiers = restrictFrontiers()
+             bestScore, _ = grammarScore(bestGrammar)
+             eprint("Starting score", bestScore)
+             while True:
+                 restrictedFrontiers = restrictFrontiers()
+                 fragments = [f
+                              for f in proposeFragmentsFromFrontiers(restrictedFrontiers, a, CPUs=CPUs)
+                              if f not in bestGrammar.primitives
+                              and defragment(f) not in bestGrammar.primitives]
+                 eprint("Proposed %d fragments." % len(fragments))
+
+                 candidateGrammars = [
+                     FragmentGrammar.uniform(
+                         bestGrammar.primitives +
+                         [fragment]) for fragment in fragments]
+                 if not candidateGrammars:
+                     break
+
+                 scoredFragments = parallelMap(CPUs, grammarScore, candidateGrammars,
+                                               # Each process handles up to 100
+                                               # grammars at a time, a "job"
+                                               chunksize=max(
+                                                   1, min(len(candidateGrammars) // CPUs, 100)),
+                                               # maxtasksperchild: maximum number of jobs allocated to a process.
+                                               # After evaluating this * chunksize many grammars,
+                                               # we kill the process, freeing up its memory.
+                                               # In exchange we pay the cost of spawning a new process.
+                                               # We should play with this number,
+                                               # figuring out how big we can make it without
+                                               # running out of memory.
+                                               maxtasksperchild=5)
+                 newScore, newGrammar = max(scoredFragments, key=lambda sg: sg[0])
+
+                 if newScore <= bestScore:
+                     break
+                 dS = newScore - bestScore
+                 bestScore, bestGrammar = newScore, newGrammar
+                 newPrimitiveLikelihood, newType, newPrimitive = bestGrammar.productions[-1]
+                 expectedUses = bestGrammar.expectedUses(
+                     restrictedFrontiers).actualUses.get(newPrimitive, 0)
+                 eprint(
+                     "New primitive of type %s\t%s\t\n(score = %f; dScore = %f; <uses> = %f)" %
+                     (newType, newPrimitive, newScore, dS, expectedUses))
+
+                 # Rewrite the frontiers in terms of the new fragment
+                 concretePrimitive = defragment(newPrimitive)
+                 bestGrammar.productions[-1] = (newPrimitiveLikelihood,
+                                                concretePrimitive.tp,
+                                                concretePrimitive)
+                 frontiers = parallelMap(
+                     CPUs, lambda frontier: bestGrammar.rescoreFrontier(
+                         RewriteFragments.rewriteFrontier(
+                             frontier, newPrimitive)), frontiers)
+                 eprint(
+                     "\t(<uses> in rewritten frontiers: %f)" %
+                     (bestGrammar.expectedUses(frontiers).actualUses[concretePrimitive]))
+         else:
+             eprint("Skipping fragment proposals")
+
+         if False:
+             # Reestimate the parameters using the entire frontiers
+             bestGrammar = bestGrammar.makeUniform().insideOutside(frontiers, pseudoCounts)
+         elif True:
+             # Reestimate the parameters using the best programs
+             restrictedFrontiers = restrictFrontiers()
+             bestGrammar = bestGrammar.makeUniform().insideOutside(
+                 restrictedFrontiers, pseudoCounts)
+         else:
+             # Use parameters that were found during search
+             pass
+
+         eprint("Old joint = %f\tNew joint = %f\n" %
+                (oldJoint, bestGrammar.jointFrontiersMDL(frontiers, CPUs=CPUs)))
+         # Return all of the frontiers, which have now been rewritten to use the
+         # new fragments
+         frontiers = {f.task: f for f in frontiers}
+         frontiers = [frontiers.get(f.task, f)
+                      for f in originalFrontiers]
+
+         productionUses = bestGrammar.expectedUses(
+             [f for f in frontiers if not f.empty]).actualUses
+         productionUses = {
+             p: productionUses.get(
+                 p, 0.) for p in bestGrammar.primitives}
+         possibleUses = bestGrammar.expectedUses(
+             [f for f in frontiers if not f.empty]).possibleUses
+         possibleUses = {
+             p: possibleUses.get(
+                 p, 0.) for p in bestGrammar.primitives}
+
+         for p in bestGrammar.primitives:
+             eprint("%f / %f\t%s" % (productionUses[p],
+                                     possibleUses[p],
+                                     p))
+
+         bestGrammar.clearCache()
+
+         grammar = bestGrammar.toGrammar()
+
+         if False and \
+                 any(productionUses.get(p, 0) < 0.5 for p in grammar.primitives if p.isInvented):
+             uselessProductions = [p for p in grammar.primitives
+                                   if p.isInvented and productionUses.get(p, 0) < 0.5]
+             eprint("The following invented primitives are no longer needed; removing them...")
+             eprint("\t" + "\t\n".join(map(str, uselessProductions)))
+             grammar = grammar.removeProductions(uselessProductions)
+
+         return grammar, frontiers
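
A minimal sketch of how the compression entry point above is typically invoked (illustrative only, not part of this commit). `primitives` and `solvedFrontiers` are hypothetical placeholders for a primitive library and the non-empty frontiers produced by enumeration, and every hyperparameter value shown is illustrative:

from dreamcoder.grammar import Grammar
from dreamcoder.fragmentGrammar import FragmentGrammar

# `primitives` and `solvedFrontiers` are hypothetical stand-ins.
g0 = Grammar.uniform(primitives)
newGrammar, rewrittenFrontiers = FragmentGrammar.induceFromFrontiers(
    g0, solvedFrontiers,
    topK=2,                # keep only the 2 best programs per task when scoring
    pseudoCounts=1.0,      # smoothing for the inside-outside reestimation
    aic=1.0,               # penalty per production added to the grammar
    structurePenalty=1.5,  # penalty per unit of fragment size
    a=3,                   # maximum fragment arity
    CPUs=4)                # parallelism for proposing and scoring fragments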
dreamcoder/fragmentUtilities.py ADDED
@@ -0,0 +1,405 @@
+ from dreamcoder.type import *
+ from dreamcoder.program import *
+ from dreamcoder.frontier import *
+
+ from collections import Counter
+
+
+ class MatchFailure(Exception):
+     pass
+
+
+ class Matcher(object):
+     def __init__(self, context):
+         self.context = context
+         self.variableBindings = {}
+
+     @staticmethod
+     def match(context, fragment, expression, numberOfArguments):
+         if not mightMatch(fragment, expression):
+             raise MatchFailure()
+         m = Matcher(context)
+         tp = fragment.visit(m, expression, [], numberOfArguments)
+         return m.context, tp, m.variableBindings
+
+     def application(self, fragment, expression, environment, numberOfArguments):
+         '''returns tp of fragment.'''
+         if not isinstance(expression, Application):
+             raise MatchFailure()
+
+         ft = fragment.f.visit(self, expression.f, environment, numberOfArguments + 1)
+         xt = fragment.x.visit(self, expression.x, environment, 0)
+
+         self.context, returnType = self.context.makeVariable()
+         try:
+             self.context = self.context.unify(ft, arrow(xt, returnType))
+         except UnificationFailure:
+             raise MatchFailure()
+
+         return returnType.apply(self.context)
+
+     def index(self, fragment, expression, environment, numberOfArguments):
+         # This is a bound variable
+         surroundingAbstractions = len(environment)
+         if fragment.bound(surroundingAbstractions):
+             if expression == fragment:
+                 return environment[fragment.i].apply(self.context)
+             else:
+                 raise MatchFailure()
+
+         # This is a free variable
+         i = fragment.i - surroundingAbstractions
+
+         # Make sure that it doesn't refer to anything bound by a
+         # lambda in the fragment. Otherwise it cannot be safely lifted
+         # out of the fragment and preserve semantics
+         for fv in expression.freeVariables():
+             if fv < len(environment):
+                 raise MatchFailure()
+
+         # The value is going to be lifted out of the fragment
+         try:
+             expression = expression.shift(-surroundingAbstractions)
+         except ShiftFailure:
+             raise MatchFailure()
+
+         # Wrap it in the appropriate number of lambda expressions & applications.
+         # This is because everything has to be in eta-long form
+         if numberOfArguments > 0:
+             expression = expression.shift(numberOfArguments)
+             for j in reversed(range(numberOfArguments)):
+                 expression = Application(expression, Index(j))
+             for _ in range(numberOfArguments):
+                 expression = Abstraction(expression)
+
+         # Add it to the bindings
+         if i in self.variableBindings:
+             (tp, binding) = self.variableBindings[i]
+             if binding != expression:
+                 raise MatchFailure()
+         else:
+             self.context, tp = self.context.makeVariable()
+             self.variableBindings[i] = (tp, expression)
+         return tp
+
+     def abstraction(self, fragment, expression, environment, numberOfArguments):
+         if not isinstance(expression, Abstraction):
+             raise MatchFailure()
+
+         self.context, argumentType = self.context.makeVariable()
+         returnType = fragment.body.visit(
+             self, expression.body, [argumentType] + environment, 0)
+
+         return arrow(argumentType, returnType)
+
+     def primitive(self, fragment, expression, environment, numberOfArguments):
+         if fragment != expression:
+             raise MatchFailure()
+         self.context, tp = fragment.tp.instantiate(self.context)
+         return tp
+
+     def invented(self, fragment, expression, environment, numberOfArguments):
+         if fragment != expression:
+             raise MatchFailure()
+         self.context, tp = fragment.tp.instantiate(self.context)
+         return tp
+
+     def fragmentVariable(self, fragment, expression, environment, numberOfArguments):
+         raise Exception(
+             'Deprecated: matching against fragment variables. Convert the fragment to canonical form to get rid of fragment variables.')
+
+
+ def mightMatch(f, e, d=0):
+     '''Checks whether fragment f might be able to match against expression e'''
+     if f.isIndex:
+         if f.bound(d):
+             return f == e
+         return True
+     if f.isPrimitive or f.isInvented:
+         return f == e
+     if f.isAbstraction:
+         return e.isAbstraction and mightMatch(f.body, e.body, d + 1)
+     if f.isApplication:
+         return e.isApplication and mightMatch(f.x, e.x, d) and mightMatch(f.f, e.f, d)
+     assert False
+
+
+ def canonicalFragment(expression):
+     '''
+     Puts a fragment into a canonical form:
+     1. removes all FragmentVariable's
+     2. renames all free variables based on a depth-first traversal
+     '''
+     return expression.visit(CanonicalVisitor(), 0)
+
+
+ class CanonicalVisitor(object):
+     def __init__(self):
+         self.numberOfAbstractions = 0
+         self.mapping = {}
+
+     def fragmentVariable(self, e, d):
+         self.numberOfAbstractions += 1
+         return Index(self.numberOfAbstractions + d - 1)
+
+     def primitive(self, e, d): return e
+
+     def invented(self, e, d): return e
+
+     def application(self, e, d):
+         return Application(e.f.visit(self, d), e.x.visit(self, d))
+
+     def abstraction(self, e, d):
+         return Abstraction(e.body.visit(self, d + 1))
+
+     def index(self, e, d):
+         if e.bound(d):
+             return e
+         i = e.i - d
+         if i in self.mapping:
+             return Index(d + self.mapping[i])
+         self.mapping[i] = self.numberOfAbstractions
+         self.numberOfAbstractions += 1
+         return Index(self.numberOfAbstractions - 1 + d)
+
+
+ def fragmentSize(f, boundVariableCost=0.1, freeVariableCost=0.01):
+     freeVariables = 0
+     leaves = 0
+     boundVariables = 0
+     for surroundingAbstractions, e in f.walk():
+         if isinstance(e, (Primitive, Invented)):
+             leaves += 1
+         if isinstance(e, Index):
+             if e.bound(surroundingAbstractions):
+                 boundVariables += 1
+             else:
+                 freeVariables += 1
+         assert not isinstance(e, FragmentVariable)
+     return leaves + boundVariableCost * boundVariables + freeVariableCost * freeVariables
+
+
+ def primitiveSize(e):
+     if e.isInvented:
+         e = e.body
+     return fragmentSize(e)
+
+
+ def defragment(expression):
+     '''Converts a fragment into an invented primitive'''
+     if isinstance(expression, (Primitive, Invented)):
+         return expression
+
+     expression = canonicalFragment(expression)
+
+     for _ in range(expression.numberOfFreeVariables):
+         expression = Abstraction(expression)
+
+     return Invented(expression)
+
+
+ class RewriteFragments(object):
+     def __init__(self, fragment):
+         self.fragment = fragment
+         self.concrete = defragment(fragment)
+
+     def tryRewrite(self, e, numberOfArguments):
+         try:
+             context, t, bindings = Matcher.match(
+                 Context.EMPTY, self.fragment, e, numberOfArguments)
+         except MatchFailure:
+             return None
+
+         assert frozenset(bindings.keys()) == frozenset(range(len(bindings))), \
+             "Perhaps the fragment is not in canonical form?"
+         e = self.concrete
+         for j in range(len(bindings) - 1, -1, -1):
+             _, b = bindings[j]
+             e = Application(e, b)
+         return e
+
+     def application(self, e, numberOfArguments):
+         e = Application(e.f.visit(self, numberOfArguments + 1),
+                         e.x.visit(self, 0))
+         return self.tryRewrite(e, numberOfArguments) or e
+
+     def index(self, e, numberOfArguments): return e
+
+     def invented(self, e, numberOfArguments): return e
+
+     def primitive(self, e, numberOfArguments): return e
+
+     def abstraction(self, e, numberOfArguments):
+         e = Abstraction(e.body.visit(self, 0))
+         return self.tryRewrite(e, numberOfArguments) or e
+
+     def rewrite(self, e): return e.visit(self, 0)
+
+     @staticmethod
+     def rewriteFrontier(frontier, fragment):
+         worker = RewriteFragments(fragment)
+         return Frontier([FrontierEntry(program=worker.rewrite(e.program),
+                                        logLikelihood=e.logLikelihood,
+                                        logPrior=e.logPrior,
+                                        logPosterior=e.logPosterior)
+                          for e in frontier],
+                         task=frontier.task)
+
+
+ def proposeFragmentsFromFragment(f):
+     '''Abstracts out repeated structure within a single fragment'''
+     yield f
+     freeVariables = f.numberOfFreeVariables
+     closedSubtrees = Counter(subtree for _, subtree in f.walk()
+                              if not isinstance(subtree, Index) and subtree.closed)
+     del closedSubtrees[f]
+     for subtree, freq in closedSubtrees.items():
+         if freq < 2:
+             continue
+         yield canonicalFragment(f.substitute(subtree, Index(freeVariables)))
+
+
+ def nontrivial(f):
+     if not isinstance(f, Application):
+         return False
+     # Curry
+     if isinstance(f.x, FragmentVariable):
+         return False
+     if isinstance(f.x, Index):
+         # Make sure that the index is used somewhere else
+         if not any(isinstance(child, Index) and
+                    child.i - surroundingAbstractions == f.x.i
+                    for surroundingAbstractions, child in f.f.walk()):
+             return False
+
+     numberOfHoles = 0
+     numberOfVariables = 0
+     numberOfPrimitives = 0
+     for surroundingAbstractions, child in f.walk():
+         if isinstance(child, (Primitive, Invented)):
+             numberOfPrimitives += 1
+         if isinstance(child, FragmentVariable):
+             numberOfHoles += 1
+         if isinstance(child, Index) and child.free(surroundingAbstractions):
+             numberOfVariables += 1
+     # eprint("Fragment %s has %d calls and %d variables and %d primitives" % (f, numberOfHoles, numberOfVariables, numberOfPrimitives))
+
+     return numberOfPrimitives + 0.5 * (numberOfHoles + numberOfVariables) > 1.5 \
+         and numberOfPrimitives >= 1
+
+
+ def violatesLaziness(fragment):
+     """
+     Conditionals are lazy in the second and third arguments; this
+     invariant must be maintained by learned fragments.
+     """
+     for surroundingAbstractions, child in fragment.walkUncurried():
+         if not child.isApplication:
+             continue
+         f, xs = child.applicationParse()
+         if not (f.isPrimitive and f.name == "if"):
+             continue
+
+         # curried conditionals always violate laziness
+         if len(xs) != 3:
+             return True
+
+         # yes/no branches
+         y = xs[1]
+         n = xs[2]
+
+         return \
+             any(yc.isIndex and yc.i >= yd
+                 for yd, yc in y.walk(surroundingAbstractions)) or \
+             any(nc.isIndex and nc.i >= nd
+                 for nd, nc in n.walk(surroundingAbstractions))
+
+     return False
+
+
+ def proposeFragmentsFromProgram(p, arity):
+
+     def fragment(expression, a, toplevel=True):
+         """Generates fragments that unify with expression"""
+
+         if a == 1:
+             yield FragmentVariable.single
+         if a == 0:
+             yield expression
+             return
+
+         if isinstance(expression, Abstraction):
+             # Symmetry breaking: (\x \y \z ... f(x,y,z,...)) defragments to be
+             # the same as f(x,y,z,...)
+             if not toplevel:
+                 for b in fragment(expression.body, a, toplevel=False):
+                     yield Abstraction(b)
+         elif isinstance(expression, Application):
+             for fa in range(a + 1):
+                 for f in fragment(expression.f, fa, toplevel=False):
+                     for x in fragment(expression.x, a - fa, toplevel=False):
+                         yield Application(f, x)
+         else:
+             assert isinstance(expression, (Invented, Primitive, Index))
+
+     def fragments(expression, a):
+         """Generates fragments that unify with subexpressions of expression"""
+
+         yield from fragment(expression, a)
+         if isinstance(expression, Application):
+             curry = True
+             if curry:
+                 yield from fragments(expression.f, a)
+                 yield from fragments(expression.x, a)
+             else:
+                 # Pretend that it is not curried
+                 function, arguments = expression.applicationParse()
+                 yield from fragments(function, a)
+                 for argument in arguments:
+                     yield from fragments(argument, a)
+         elif isinstance(expression, Abstraction):
+             yield from fragments(expression.body, a)
+         else:
+             assert isinstance(expression, (Invented, Primitive, Index))
+
+     return {canonicalFragment(f) for b in range(arity + 1)
+             for f in fragments(p, b) if nontrivial(f)}
+
+
+ def proposeFragmentsFromFrontiers(frontiers, a, CPUs=1):
+     fragmentsFromEachFrontier = parallelMap(
+         CPUs,
+         lambda frontier: {fp
+                           for entry in frontier.entries
+                           for f in proposeFragmentsFromProgram(entry.program, a)
+                           for fp in proposeFragmentsFromFragment(f)},
+         frontiers)
+     allFragments = Counter(f for frontierFragments in fragmentsFromEachFrontier
+                            for f in frontierFragments)
+     return [fragment for fragment, frequency in allFragments.items()
+             if frequency >= 2 and fragment.wellTyped() and nontrivial(fragment)]
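
A small illustration (not part of the commit) of the fragment-proposal pipeline above. It assumes the arithmetic domain registers the "+" and "1" primitives so that Program.parse can resolve them:

import dreamcoder.domains.arithmetic.arithmeticPrimitives  # registers "+", "1", ... in Primitive.GLOBALS
from dreamcoder.program import Program
from dreamcoder.fragmentUtilities import proposeFragmentsFromProgram, defragment

# A program with repeated structure; "(+ 1 ...)" should surface as a fragment.
p = Program.parse("(lambda (+ 1 (+ 1 $0)))")
for f in proposeFragmentsFromProgram(p, arity=1):
    print(f, "=>", defragment(f))  # each fragment and its closed, invented form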
dreamcoder/frontier.py ADDED
@@ -0,0 +1,247 @@
+ from dreamcoder.utilities import *
+ from dreamcoder.program import *
+ from dreamcoder.task import Task
+
+
+ class FrontierEntry(object):
+     def __init__(self,
+                  program,
+                  _=None,
+                  logPrior=None,
+                  logLikelihood=None,
+                  logPosterior=None):
+         self.logPosterior = logPrior + logLikelihood if logPosterior is None else logPosterior
+         self.program = program
+         self.logPrior = logPrior
+         self.logLikelihood = logLikelihood
+
+     def __repr__(self):
+         return "FrontierEntry(program={self.program}, logPrior={self.logPrior}, logLikelihood={self.logLikelihood})".format(
+             self=self)
+
+     def strip_primitive_values(self):
+         return FrontierEntry(program=strip_primitive_values(self.program),
+                              logPrior=self.logPrior,
+                              logPosterior=self.logPosterior,
+                              logLikelihood=self.logLikelihood)
+
+     def unstrip_primitive_values(self):
+         return FrontierEntry(program=unstrip_primitive_values(self.program),
+                              logPrior=self.logPrior,
+                              logPosterior=self.logPosterior,
+                              logLikelihood=self.logLikelihood)
+
+
+ class Frontier(object):
+     def __init__(self, frontier, task):
+         self.entries = frontier
+         self.task = task
+
+     def __repr__(self):
+         return "Frontier(entries={self.entries}, task={self.task})".format(self=self)
+
+     def __iter__(self): return iter(self.entries)
+
+     def __len__(self): return len(self.entries)
+
+     def json(self):
+         return {"request": self.task.request.json(),
+                 "task": str(self.task),
+                 "programs": [{"program": str(e.program),
+                               "logLikelihood": e.logLikelihood}
+                              for e in self]}
+
+     def strip_primitive_values(self):
+         return Frontier([e.strip_primitive_values() for e in self.entries],
+                         self.task)
+
+     def unstrip_primitive_values(self):
+         return Frontier([e.unstrip_primitive_values() for e in self.entries],
+                         self.task)
+
+     DUMMYFRONTIERCOUNTER = 0
+
+     @staticmethod
+     def dummy(program, logLikelihood=0., logPrior=0., tp=None):
+         """Creates a dummy frontier containing just this program"""
+         if not tp:
+             tp = program.infer().negateVariables()
+
+         t = Task("<dummy %d: %s>" % (Frontier.DUMMYFRONTIERCOUNTER, str(program)),
+                  tp,
+                  [])
+         f = Frontier([FrontierEntry(program=program,
+                                     logLikelihood=logLikelihood,
+                                     logPrior=logPrior)],
+                      task=t)
+         Frontier.DUMMYFRONTIERCOUNTER += 1
+         return f
+
+     def marginalLikelihood(self):
+         return lse([e.logPrior + e.logLikelihood for e in self])
+
+     def temperature(self, T):
+         """Divides the prior by T"""
+         return Frontier([FrontierEntry(program=e.program,
+                                        logPrior=e.logPrior / T,
+                                        logLikelihood=e.logLikelihood)
+                          for e in self],
+                         task=self.task)
+
+     def normalize(self):
+         z = self.marginalLikelihood()
+         newEntries = [
+             FrontierEntry(program=e.program,
+                           logPrior=e.logPrior,
+                           logLikelihood=e.logLikelihood,
+                           logPosterior=e.logPrior + e.logLikelihood - z)
+             for e in self]
+         newEntries.sort(key=lambda e: e.logPosterior, reverse=True)
+         return Frontier(newEntries,
+                         self.task)
+
+     def expectedProductionUses(self, g):
+         """Returns a vector of the expected number of times each production was used"""
+         import numpy as np
+
+         this = g.rescoreFrontier(self).normalize()
+         ps = list(sorted(g.primitives, key=str))
+         features = np.zeros(len(ps))
+
+         for j, p in enumerate(ps):
+             for e in this:
+                 w = math.exp(e.logPosterior)
+                 features[j] += w * sum(child == p
+                                        for _, child in e.program.walk())
+             if not p.isInvented: features[j] *= 0.3
+         return features
+
+     def removeZeroLikelihood(self):
+         self.entries = [
+             e for e in self.entries if e.logLikelihood != float('-inf')]
+         return self
+
+     def topK(self, k):
+         if k == 0: return Frontier([], self.task)
+         if k < 0: return self
+         newEntries = sorted(self.entries,
+                             key=lambda e: (-e.logPosterior, str(e.program)))
+         return Frontier(newEntries[:k], self.task)
+
+     def sample(self):
+         """Samples an entry from a frontier"""
+         return sampleLogDistribution([(e.logLikelihood + e.logPrior, e)
+                                       for e in self])
+
+     @property
+     def bestPosterior(self):
+         return min(self.entries,
+                    key=lambda e: (-e.logPosterior, str(e.program)))
+
+     def replaceWithSupervised(self, g):
+         assert self.task.supervision is not None
+         return g.rescoreFrontier(Frontier([FrontierEntry(self.task.supervision,
+                                                          logLikelihood=0., logPrior=0.)],
+                                           task=self.task))
+
+     @property
+     def bestll(self):
+         best = max(self.entries,
+                    key=lambda e: e.logLikelihood)
+         return best.logLikelihood
+
+     @property
+     def empty(self): return self.entries == []
+
+     @staticmethod
+     def makeEmpty(task):
+         return Frontier([], task=task)
+
+     def summarize(self):
+         if self.empty:
+             return "MISS " + self.task.name
+         best = self.bestPosterior
+         return "HIT %s w/ %s ; log prior = %f ; log likelihood = %f" % (
+             self.task.name, best.program, best.logPrior, best.logLikelihood)
+
+     def summarizeFull(self):
+         if self.empty:
+             return "MISS " + self.task.name
+         return "\n".join([self.task.name] +
+                          ["%f\t%s" % (e.logPosterior, e.program)
+                           for e in self.normalize()])
+
+     @staticmethod
+     def describe(frontiers):
+         numberOfHits = sum(not f.empty for f in frontiers)
+         if numberOfHits > 0:
+             averageLikelihood = sum(
+                 f.bestPosterior.logPrior for f in frontiers if not f.empty) / numberOfHits
+         else:
+             averageLikelihood = 0
+         return "\n".join([f.summarize() for f in frontiers] +
+                          ["Hits %d/%d tasks" % (numberOfHits, len(frontiers))] +
+                          ["Average description length of a program solving a task: %f nats" % (-averageLikelihood)])
+
+     def combine(self, other, tolerance=0.01):
+         '''Takes the union of the programs in each of the frontiers'''
+         assert self.task == other.task
+
+         foundDifference = False
+
+         x = {e.program: e for e in self}
+         y = {e.program: e for e in other}
+         programs = set(x.keys()) | set(y.keys())
+         union = []
+         for p in programs:
+             if p in x:
+                 e1 = x[p]
+                 if p in y:
+                     e2 = y[p]
+                     if abs(e1.logPrior - e2.logPrior) > tolerance:
+                         eprint(
+                             "WARNING: Log priors differed during frontier combining: %f vs %f" %
+                             (e1.logPrior, e2.logPrior))
+                         eprint("WARNING: \tThe program is", p)
+                         eprint()
+                     if abs(e1.logLikelihood - e2.logLikelihood) > tolerance:
+                         foundDifference = True
+                         eprint(
+                             "WARNING: Log likelihoods differed for %s: %f & %f" %
+                             (p, e1.logLikelihood, e2.logLikelihood))
+                         if hasattr(self.task, 'BIC'):
+                             eprint("\t%d examples, BIC=%f, parameterPenalty=%f, n parameters=%d, correct likelihood=%f" %
+                                    (len(self.task.examples),
+                                     self.task.BIC,
+                                     self.task.BIC * math.log(len(self.task.examples)),
+                                     substringOccurrences("REAL", str(p)),
+                                     substringOccurrences("REAL", str(p)) * self.task.BIC * math.log(len(self.task.examples))))
+                             e1.logLikelihood = - \
+                                 substringOccurrences("REAL", str(p)) * self.task.BIC * math.log(len(self.task.examples))
+                             e2.logLikelihood = e1.logLikelihood
+
+                     e1 = FrontierEntry(
+                         program=e1.program,
+                         logLikelihood=(e1.logLikelihood + e2.logLikelihood) / 2,
+                         logPrior=e1.logPrior)
+             else:
+                 e1 = y[p]
+             union.append(e1)
+
+         if foundDifference:
+             eprint(
+                 "WARNING: Log likelihoods differed for the same program on the task %s.\n" %
+                 (self.task.name),
+                 "\tThis is acceptable only if the likelihood model is stochastic. Took the geometric mean of the likelihoods.")
+
+         return Frontier(union, self.task)
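
A short sketch (not part of the commit) of typical Frontier usage; `scoredPrograms` and `task` are hypothetical placeholders for enumerator output and a Task object:

from dreamcoder.frontier import Frontier, FrontierEntry

# `scoredPrograms` is a hypothetical list of (program, logPrior, logLikelihood).
f = Frontier([FrontierEntry(p, logPrior=lp, logLikelihood=ll)
              for p, lp, ll in scoredPrograms],
             task=task)
best = f.normalize().topK(1).bestPosterior  # highest-posterior entry
print(f.summarize())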
dreamcoder/grammar.py ADDED
@@ -0,0 +1,1308 @@
1
+ from collections import defaultdict
2
+
3
+ from dreamcoder.frontier import *
4
+ from dreamcoder.program import *
5
+ from dreamcoder.type import *
6
+ from dreamcoder.utilities import *
7
+
8
+ import time
9
+
10
+ class GrammarFailure(Exception):
11
+ pass
12
+
13
+ class SketchEnumerationFailure(Exception):
14
+ pass
15
+
16
+ class NoCandidates(Exception):
17
+ pass
18
+
19
+
20
+ class Grammar(object):
21
+ def __init__(self, logVariable, productions, continuationType=None):
22
+ self.logVariable = logVariable
23
+ self.productions = productions
24
+
25
+ self.continuationType = continuationType
26
+
27
+ self.expression2likelihood = dict((p, l) for l, _, p in productions)
28
+ self.expression2likelihood[Index(0)] = self.logVariable
29
+
30
+ def randomWeights(self, r):
31
+ """returns a new grammar with random weights drawn from r. calls `r` w/ old weight"""
32
+ return Grammar(logVariable=r(self.logVariable),
33
+ productions=[(r(l),t,p)
34
+ for l,t,p in self.productions ],
35
+ continuationType=self.continuationType)
36
+
37
+ def strip_primitive_values(self):
38
+ return Grammar(logVariable=self.logVariable,
39
+ productions=[(l,t,strip_primitive_values(p))
40
+ for l,t,p in self.productions ],
41
+ continuationType=self.continuationType)
42
+
43
+ def unstrip_primitive_values(self):
44
+ return Grammar(logVariable=self.logVariable,
45
+ productions=[(l,t,unstrip_primitive_values(p))
46
+ for l,t,p in self.productions ],
47
+ continuationType=self.continuationType)
48
+
49
+ def __setstate__(self, state):
50
+ """
51
+ Legacy support for loading grammar objects without the imperative type filled in
52
+ """
53
+ assert 'logVariable' in state
54
+ assert 'productions' in state
55
+ if 'continuationType' in state:
56
+ continuationType = state['continuationType']
57
+ else:
58
+ if any( 'turtle' in str(t) for l,t,p in state['productions'] ):
59
+ continuationType = baseType("turtle")
60
+ elif any( 'tower' in str(t) for l,t,p in state['productions'] ):
61
+ continuationType = baseType("tower")
62
+ else:
63
+ continuationType = None
64
+
65
+ self.__init__(state['logVariable'], state['productions'], continuationType=continuationType)
66
+
67
+ @staticmethod
68
+ def fromProductions(productions, logVariable=0.0, continuationType=None):
69
+ """Make a grammar from primitives and their relative logpriors."""
70
+ return Grammar(logVariable, [(l, p.infer(), p)
71
+ for l, p in productions],
72
+ continuationType=continuationType)
73
+
74
+ @staticmethod
75
+ def uniform(primitives, continuationType=None):
76
+ return Grammar(0.0, [(0.0, p.infer(), p) for p in primitives], continuationType=continuationType)
77
+
78
+ def __len__(self): return len(self.productions)
79
+
80
+ def __str__(self):
81
+ def productionKey(xxx_todo_changeme):
82
+ (l, t, p) = xxx_todo_changeme
83
+ return not isinstance(p, Primitive), l is not None and -l
84
+ if self.continuationType is not None:
85
+ lines = ["continuation : %s"%self.continuationType]
86
+ else:
87
+ lines = []
88
+ lines += ["%f\tt0\t$_" % self.logVariable]
89
+ for l, t, p in sorted(self.productions, key=productionKey):
90
+ if l is not None:
91
+ l = "%f\t%s\t%s" % (l, t, p)
92
+ else:
93
+ l = "-Inf\t%s\t%s" % (t, p)
94
+ if not t.isArrow() and isinstance(p, Invented):
95
+ try:
96
+ l += "\teval = %s" % (p.evaluate([]))
97
+ except BaseException:
98
+ pass
99
+
100
+ lines.append(l)
101
+ return "\n".join(lines)
102
+
103
+ def json(self):
104
+ j = {"logVariable": self.logVariable,
105
+ "productions": [{"expression": str(p), "logProbability": l}
106
+ for l, _, p in self.productions]}
107
+ if self.continuationType is not None:
108
+ j["continuationType"] = self.continuationType.json()
109
+ return j
110
+
111
+ def _immutable_code(self): return self.logVariable, tuple(self.productions)
112
+
113
+ def __eq__(self, o): return self._immutable_code() == o._immutable_code()
114
+
115
+ def __ne__(self, o): return not (self == o)
116
+
117
+ def __hash__(self): return hash(self._immutable_code())
118
+
119
+ @property
120
+ def primitives(self):
121
+ return [p for _, _, p in self.productions]
122
+
123
+ def removeProductions(self, ps):
124
+ return Grammar(
125
+ self.logVariable, [
126
+ (l, t, p) for (
127
+ l, t, p) in self.productions if p not in ps],
128
+ continuationType=self.continuationType)
129
+
130
+ def buildCandidates(self, request, context, environment,
131
+ # Should the log probabilities be normalized?
132
+ normalize=True,
133
+ # Should be returned a table mapping primitives to
134
+ # their candidate entry?
135
+ returnTable=False,
136
+ # Should we return probabilities vs log probabilities?
137
+ returnProbabilities=False,
138
+ # Must be a leaf (have no arguments)?
139
+ mustBeLeaf=False):
140
+ """Primitives that are candidates for being used given a requested type
141
+ If returnTable is false (default): returns [((log)likelihood, tp, primitive, context)]
142
+ if returntable is true: returns {primitive: ((log)likelihood, tp, context)}"""
143
+ if returnProbabilities:
144
+ assert normalize
145
+
146
+ candidates = []
147
+ variableCandidates = []
148
+ for l, t, p in self.productions:
149
+ try:
150
+ newContext, t = t.instantiate(context)
151
+ newContext = newContext.unify(t.returns(), request)
152
+ t = t.apply(newContext)
153
+ if mustBeLeaf and t.isArrow():
154
+ continue
155
+ candidates.append((l, t, p, newContext))
156
+ except UnificationFailure:
157
+ continue
158
+ for j, t in enumerate(environment):
159
+ try:
160
+ newContext = context.unify(t.returns(), request)
161
+ t = t.apply(newContext)
162
+ if mustBeLeaf and t.isArrow():
163
+ continue
164
+ variableCandidates.append((t, Index(j), newContext))
165
+ except UnificationFailure:
166
+ continue
167
+
168
+ if self.continuationType == request:
169
+ terminalIndices = [v.i for t,v,k in variableCandidates if not t.isArrow()]
170
+ if terminalIndices:
171
+ smallestIndex = Index(min(terminalIndices))
172
+ variableCandidates = [(t,v,k) for t,v,k in variableCandidates
173
+ if t.isArrow() or v == smallestIndex]
174
+
175
+ candidates += [(self.logVariable - log(len(variableCandidates)), t, p, k)
176
+ for t, p, k in variableCandidates]
177
+ if candidates == []:
178
+ raise NoCandidates()
179
+ #eprint("candidates inside buildCandidates before norm:")
180
+ #eprint(candidates)
181
+
182
+ if normalize:
183
+ z = lse([l for l, t, p, k in candidates])
184
+ if returnProbabilities:
185
+ candidates = [(exp(l - z), t, p, k)
186
+ for l, t, p, k in candidates]
187
+ else:
188
+ candidates = [(l - z, t, p, k) for l, t, p, k in candidates]
189
+
190
+ #eprint("candidates inside buildCandidates after norm:")
191
+ #eprint(candidates)
192
+
193
+ if returnTable:
194
+ return {p: (l, t, k) for l, t, p, k in candidates}
195
+ else:
196
+ return candidates
197
+
198
+
199
+ def sample(self, request, maximumDepth=6, maxAttempts=None):
200
+ attempts = 0
201
+
202
+ while True:
203
+ try:
204
+ _, e = self._sample(
205
+ request, Context.EMPTY, [], maximumDepth=maximumDepth)
206
+ return e
207
+ except NoCandidates:
208
+ if maxAttempts is not None:
209
+ attempts += 1
210
+ if attempts > maxAttempts:
211
+ return None
212
+ continue
213
+
214
+ def _sample(self, request, context, environment, maximumDepth):
215
+ if request.isArrow():
216
+ context, expression = self._sample(
217
+ request.arguments[1], context, [
218
+ request.arguments[0]] + environment, maximumDepth)
219
+ return context, Abstraction(expression)
220
+
221
+ candidates = self.buildCandidates(request, context, environment,
222
+ normalize=True,
223
+ returnProbabilities=True,
224
+ # Force it to terminate in a
225
+ # leaf; a primitive with no
226
+ # function arguments
227
+ mustBeLeaf=maximumDepth <= 1)
228
+ #eprint("candidates:")
229
+ #eprint(candidates)
230
+ newType, chosenPrimitive, context = sampleDistribution(candidates)
231
+
232
+ # Sample the arguments
233
+ xs = newType.functionArguments()
234
+ returnValue = chosenPrimitive
235
+
236
+ for x in xs:
237
+ x = x.apply(context)
238
+ context, x = self._sample(x, context, environment, maximumDepth - 1)
239
+ returnValue = Application(returnValue, x)
240
+
241
+ return context, returnValue
242
+
243
+ def likelihoodSummary(self, context, environment, request, expression, silent=False):
244
+ if request.isArrow():
245
+ if not isinstance(expression, Abstraction):
246
+ if not silent:
247
+ eprint("Request is an arrow but I got", expression)
248
+ return context, None
249
+ return self.likelihoodSummary(context,
250
+ [request.arguments[0]] + environment,
251
+ request.arguments[1],
252
+ expression.body,
253
+ silent=silent)
254
+ # Build the candidates
255
+ candidates = self.buildCandidates(request, context, environment,
256
+ normalize=False,
257
+ returnTable=True)
258
+
259
+ # A list of everything that would have been possible to use here
260
+ possibles = [p for p in candidates.keys() if not p.isIndex]
261
+ numberOfVariables = sum(p.isIndex for p in candidates.keys())
262
+ if numberOfVariables > 0:
263
+ possibles += [Index(0)]
264
+
265
+ f, xs = expression.applicationParse()
266
+
267
+ if f not in candidates:
268
+ if self.continuationType is not None and f.isIndex:
269
+ ls = LikelihoodSummary()
270
+ ls.constant = NEGATIVEINFINITY
271
+ return ls
272
+
273
+ if not silent:
274
+ eprint(f, "Not in candidates")
275
+ eprint("Candidates is", candidates)
276
+ #eprint("grammar:", grammar.productions)
277
+ eprint("request is", request)
278
+ eprint("xs", xs)
279
+ eprint("environment", environment)
280
+ assert False
281
+ return context, None
282
+
283
+ thisSummary = LikelihoodSummary()
284
+ thisSummary.record(f, possibles,
285
+ constant= -math.log(numberOfVariables) if f.isIndex else 0)
286
+
287
+ _, tp, context = candidates[f]
288
+ argumentTypes = tp.functionArguments()
289
+ if len(xs) != len(argumentTypes):
290
+ eprint("PANIC: not enough arguments for the type")
291
+ eprint("request", request)
292
+ eprint("tp", tp)
293
+ eprint("expression", expression)
294
+ eprint("xs", xs)
295
+ eprint("argumentTypes", argumentTypes)
296
+ # This should absolutely never occur
297
+ raise GrammarFailure((context, environment, request, expression))
298
+
299
+ for argumentType, argument in zip(argumentTypes, xs):
300
+ argumentType = argumentType.apply(context)
301
+ context, newSummary = self.likelihoodSummary(
302
+ context, environment, argumentType, argument, silent=silent)
303
+ if newSummary is None:
304
+ return context, None
305
+ thisSummary.join(newSummary)
306
+
307
+ return context, thisSummary
308
+
309
+ def bestFirstEnumeration(self, request):
310
+ from heapq import heappush, heappop
311
+
312
+ pq = []
313
+
314
+ def choices(parentCost, xs):
315
+ for c, x in xs:
316
+ heappush(pq, (parentCost + c, x))
317
+
318
+ def g(parentCost, request, _=None,
319
+ context=None, environment=[],
320
+ k=None):
321
+ """
322
+ k is a continuation.
323
+ k: Expects to be called with MDL, context, expression.
324
+ """
325
+
326
+ assert k is not None
327
+ if context is None:
328
+ context = Context.EMPTY
329
+
330
+ if request.isArrow():
331
+ g(parentCost,
332
+ request.arguments[1],
333
+ context=context,
334
+ environment=[request.arguments[0]] + environment,
335
+ k=lambda MDL,
336
+ newContext,
337
+ p: k(MDL,
338
+ newContext,
339
+ Abstraction(p)))
340
+ else:
341
+ candidates = self.buildCandidates(request,
342
+ context,
343
+ environment,
344
+ normalize=True,
345
+ returnProbabilities=False,
346
+ returnTable=True)
347
+ choices(parentCost,
348
+ [(-f_ll_tp_newContext[1][0],
349
+ lambda: ga(parentCost - f_ll_tp_newContext[1][0],
350
+ f_ll_tp_newContext[0],
351
+ f_ll_tp_newContext[1][1].functionArguments(),
352
+ context=f_ll_tp_newContext[1][2],
353
+ environment=environment,
354
+ k=k)) for f_ll_tp_newContext in iter(candidates.items())])
355
+
356
+ def ga(costSoFar, f, argumentTypes, _=None,
357
+ context=None, environment=None,
358
+ k=None):
359
+ if argumentTypes == []:
360
+ k(costSoFar, context, f)
361
+ else:
362
+ t1 = argumentTypes[0].apply(context)
363
+ g(costSoFar, t1, context=context, environment=environment,
364
+ k=lambda newCost, newContext, argument:
365
+ ga(newCost, Application(f, argument), argumentTypes[1:],
366
+ context=newContext, environment=environment,
367
+ k=k))
368
+
369
+ def receiveResult(MDL, _, expression):
370
+ heappush(pq, (MDL, expression))
371
+
372
+ g(0., request, context=Context.EMPTY, environment=[], k=receiveResult)
373
+ frontier = []
374
+ while len(frontier) < 10**3:
375
+ MDL, action = heappop(pq)
376
+ if isinstance(action, Program):
377
+ expression = action
378
+ frontier.append(expression)
379
+ #eprint("Enumerated program",expression,-MDL,self.closedLogLikelihood(request, expression))
380
+ else:
381
+ action()
382
+
383
+ def closedLikelihoodSummary(self, request, expression, silent=False):
384
+ try:
385
+ context, summary = self.likelihoodSummary(Context.EMPTY, [], request, expression, silent=silent)
386
+ except GrammarFailure as e:
387
+ failureExport = 'failures/grammarFailure%s.pickle' % (
388
+ time.time() + getPID())
389
+ eprint("PANIC: Grammar failure, exporting to ", failureExport)
390
+ with open(failureExport, 'wb') as handle:
391
+ pickle.dump((e, self, request, expression), handle)
392
+ assert False
393
+
394
+ return summary
395
+
396
+ def logLikelihood(self, request, expression):
397
+ summary = self.closedLikelihoodSummary(request, expression)
398
+ if summary is None:
399
+ eprint(
400
+ "FATAL: program [ %s ] does not have a likelihood summary." %
401
+ expression, "r = ", request, "\n", self)
402
+ assert False
403
+ return summary.logLikelihood(self)
404
+
405
+ def rescoreFrontier(self, frontier):
406
+ return Frontier([FrontierEntry(e.program,
407
+ logPrior=self.logLikelihood(frontier.task.request, e.program),
408
+ logLikelihood=e.logLikelihood)
409
+ for e in frontier],
410
+ frontier.task)
411
+
412
+ def productionUses(self, frontiers):
413
+ """Returns the expected number of times that each production was used. {production: expectedUses}"""
414
+ frontiers = [self.rescoreFrontier(f).normalize()
415
+ for f in frontiers if not f.empty]
416
+ uses = {p: 0. for p in self.primitives}
417
+ for f in frontiers:
418
+ for e in f:
419
+ summary = self.closedLikelihoodSummary(f.task.request,
420
+ e.program)
421
+ for p, u in summary.uses:
422
+ uses[p] += u * math.exp(e.logPosterior)
423
+ return uses
424
+
425
+ def insideOutside(self, frontiers, pseudoCounts, iterations=1):
426
+ # Replace programs with (likelihood summary, uses)
427
+ frontiers = [ Frontier([ FrontierEntry((summary, summary.toUses()),
428
+ logPrior=summary.logLikelihood(self),
429
+ logLikelihood=e.logLikelihood)
430
+ for e in f
431
+ for summary in [self.closedLikelihoodSummary(f.task.request, e.program)] ],
432
+ task=f.task)
433
+ for f in frontiers ]
434
+
435
+ g = self
436
+ for i in range(iterations):
437
+ u = Uses()
438
+ for f in frontiers:
439
+ f = f.normalize()
440
+ for e in f:
441
+ _, eu = e.program
442
+ u += math.exp(e.logPosterior) * eu
443
+
444
+ lv = math.log(u.actualVariables + pseudoCounts) - \
445
+ math.log(u.possibleVariables + pseudoCounts)
446
+ g = Grammar(lv,
447
+ [ (math.log(u.actualUses.get(p,0.) + pseudoCounts) - \
448
+ math.log(u.possibleUses.get(p,0.) + pseudoCounts),
449
+ t,p)
450
+ for _,t,p in g.productions ],
451
+ continuationType=self.continuationType)
452
+ if i < iterations - 1:
453
+ frontiers = [Frontier([ FrontierEntry((summary, uses),
454
+ logPrior=summary.logLikelihood(g),
455
+ logLikelihood=e.logLikelihood)
456
+ for e in f
457
+ for (summary, uses) in [e.program] ],
458
+ task=f.task)
459
+ for f in frontiers ]
460
+ return g
461
+
462
+ def frontierMDL(self, frontier):
463
+ return max( e.logLikelihood + self.logLikelihood(frontier.task.request, e.program)
464
+ for e in frontier )
465
+
466
+
467
+ def enumeration(self,context,environment,request,upperBound,
468
+ maximumDepth=20,
469
+ lowerBound=0.):
470
+ '''Enumerates all programs whose MDL satisfies: lowerBound <= MDL < upperBound'''
471
+ if upperBound < 0 or maximumDepth == 1:
472
+ return
473
+
474
+ if request.isArrow():
475
+ v = request.arguments[0]
476
+ for l, newContext, b in self.enumeration(context, [v] + environment,
477
+ request.arguments[1],
478
+ upperBound=upperBound,
479
+ lowerBound=lowerBound,
480
+ maximumDepth=maximumDepth):
481
+ yield l, newContext, Abstraction(b)
482
+
483
+ else:
484
+ candidates = self.buildCandidates(request, context, environment,
485
+ normalize=True)
486
+
487
+ for l, t, p, newContext in candidates:
488
+ mdl = -l
489
+ if not (mdl < upperBound):
490
+ continue
491
+
492
+ xs = t.functionArguments()
493
+ for aL, aK, application in\
494
+ self.enumerateApplication(newContext, environment, p, xs,
495
+ upperBound=upperBound + l,
496
+ lowerBound=lowerBound + l,
497
+ maximumDepth=maximumDepth - 1):
498
+ yield aL + l, aK, application
499
+
500
+ def enumerateApplication(self, context, environment,
501
+ function, argumentRequests,
502
+ # Upper bound on the description length of all of
503
+ # the arguments
504
+ upperBound,
505
+ # Lower bound on the description length of all of
506
+ # the arguments
507
+ lowerBound=0.,
508
+ maximumDepth=20,
509
+ originalFunction=None,
510
+ argumentIndex=0):
511
+ if upperBound < 0. or maximumDepth == 1:
512
+ return
513
+ if originalFunction is None:
514
+ originalFunction = function
515
+
516
+ if argumentRequests == []:
517
+ if lowerBound <= 0. and 0. < upperBound:
518
+ yield 0., context, function
519
+ else:
520
+ return
521
+ else:
522
+ argRequest = argumentRequests[0].apply(context)
523
+ laterRequests = argumentRequests[1:]
524
+ for argL, newContext, arg in self.enumeration(context, environment, argRequest,
525
+ upperBound=upperBound,
526
+ lowerBound=0.,
527
+ maximumDepth=maximumDepth):
528
+ if violatesSymmetry(originalFunction, arg, argumentIndex):
529
+ continue
530
+
531
+ newFunction = Application(function, arg)
532
+ for resultL, resultK, result in self.enumerateApplication(newContext, environment, newFunction,
533
+ laterRequests,
534
+ upperBound=upperBound + argL,
535
+ lowerBound=lowerBound + argL,
536
+ maximumDepth=maximumDepth,
537
+ originalFunction=originalFunction,
538
+ argumentIndex=argumentIndex + 1):
539
+ yield resultL + argL, resultK, result
540
+
541
+ def sketchEnumeration(self,context,environment,request,sk,upperBound,
542
+ maximumDepth=20,
543
+ lowerBound=0.):
544
+ '''Enumerates all sketch instantiations whose MDL satisfies: lowerBound <= MDL < upperBound'''
545
+ if upperBound < 0. or maximumDepth == 1:
546
+ return
547
+
548
+ if sk.isHole:
549
+ yield from self.enumeration(context, environment, request, upperBound,
550
+ maximumDepth=maximumDepth,
551
+ lowerBound=lowerBound)
552
+ elif request.isArrow():
553
+ assert sk.isAbstraction
554
+ v = request.arguments[0]
555
+ for l, newContext, b in self.sketchEnumeration(context, [v] + environment,
556
+ request.arguments[1],
557
+ sk.body,
558
+ upperBound=upperBound,
559
+ lowerBound=lowerBound,
560
+ maximumDepth=maximumDepth):
561
+ yield l, newContext, Abstraction(b)
562
+
563
+ else:
564
+ f, xs = sk.applicationParse()
565
+ if f.isIndex:
566
+ ft = environment[f.i].apply(context)
567
+ elif f.isInvented or f.isPrimitive:
568
+ context, ft = f.tp.instantiate(context)
569
+ elif f.isAbstraction:
570
+ assert False, "sketch is not in beta longform"
571
+ elif f.isHole:
572
+ assert False, "hole as function not yet supported"
573
+ elif f.isApplication:
574
+ assert False, "should never happen - bug in applicationParse"
575
+ else: assert False
576
+
577
+ try: context = context.unify(ft.returns(), request)
578
+ except UnificationFailure:
579
+ print("Exception: sketch is ill-typed")
580
+ return #so that we can continue evaluating
581
+ # raise SketchEnumerationFailure() #"sketch is ill-typed"
582
+ ft = ft.apply(context)
583
+ argumentRequests = ft.functionArguments()
584
+
585
+ assert len(argumentRequests) == len(xs)
586
+
587
+ yield from self.sketchApplication(context, environment,
588
+ f, xs, argumentRequests,
589
+ upperBound=upperBound,
590
+ lowerBound=lowerBound,
591
+ maximumDepth=maximumDepth - 1)
592
+
593
+
594
+ def sketchApplication(self, context, environment,
595
+ function, arguments, argumentRequests,
596
+ # Upper bound on the description length of all of
597
+ # the arguments
598
+ upperBound,
599
+ # Lower bound on the description length of all of
600
+ # the arguments
601
+ lowerBound=0.,
602
+ maximumDepth=20):
603
+ if upperBound < 0. or maximumDepth == 1:
604
+ return
605
+
606
+ if argumentRequests == []:
607
+ if lowerBound <= 0. and 0. < upperBound:
608
+ yield 0., context, function
609
+ else:
610
+ return
611
+ else:
612
+ argRequest = argumentRequests[0].apply(context)
613
+ laterRequests = argumentRequests[1:]
614
+ firstSketch = arguments[0]
615
+ laterSketches = arguments[1:]
616
+ for argL, newContext, arg in self.sketchEnumeration(context, environment, argRequest,
617
+ firstSketch,
618
+ upperBound=upperBound,
619
+ lowerBound=0.,
620
+ maximumDepth=maximumDepth):
621
+
622
+ newFunction = Application(function, arg)
623
+ for resultL, resultK, result in self.sketchApplication(newContext, environment, newFunction,
624
+ laterSketches, laterRequests,
625
+ upperBound=upperBound + argL,
626
+ lowerBound=lowerBound + argL,
627
+ maximumDepth=maximumDepth):
628
+
629
+ yield resultL + argL, resultK, result
630
+
631
+ def sketchLogLikelihood(self, request, full, sk, context=Context.EMPTY, environment=[]):
632
+ """
633
+ calculates mdl of full program 'full' from sketch 'sk'
634
+ """
635
+ if sk.isHole:
636
+ _, summary = self.likelihoodSummary(context, environment, request, full)
637
+ if summary is None:
638
+ eprint(
639
+ "FATAL: program [ %s ] does not have a likelihood summary." %
640
+ full, "r = ", request, "\n", self)
641
+ assert False
642
+ return summary.logLikelihood(self), context
643
+
644
+ elif request.isArrow():
645
+ assert sk.isAbstraction and full.isAbstraction
646
+ #assert sk.f == full.f #is this right? or do i need to recurse?
647
+ v = request.arguments[0]
648
+ return self.sketchLogLikelihood(request.arguments[1], full.body, sk.body, context=context, environment=[v] + environment)
649
+
650
+ else:
651
+ sk_f, sk_xs = sk.applicationParse()
652
+ full_f, full_xs = full.applicationParse()
653
+ if sk_f.isIndex:
654
+ assert sk_f == full_f, "sketch and full program don't match on an index"
655
+ ft = environment[sk_f.i].apply(context)
656
+ elif sk_f.isInvented or sk_f.isPrimitive:
657
+ assert sk_f == full_f, "sketch and full program don't match on a primitive"
658
+ context, ft = sk_f.tp.instantiate(context)
659
+ elif sk_f.isAbstraction:
660
+ assert False, "sketch is not in beta longform"
661
+ elif sk_f.isHole:
662
+ assert False, "hole as function not yet supported"
663
+ elif sk_f.isApplication:
664
+ assert False, "should never happen - bug in applicationParse"
665
+ else: assert False
666
+
667
+ try: context = context.unify(ft.returns(), request)
668
+ except UnificationFailure: assert False, "sketch is ill-typed"
669
+ ft = ft.apply(context)
670
+ argumentRequests = ft.functionArguments()
671
+
672
+ assert len(argumentRequests) == len(sk_xs) == len(full_xs) #this might not be true if holes??
673
+
674
+ return self.sketchllApplication(context, environment,
675
+ sk_f, sk_xs, full_f, full_xs, argumentRequests)
676
+
677
+ def sketchllApplication(self, context, environment,
678
+ sk_function, sk_arguments, full_function, full_arguments, argumentRequests):
679
+ if argumentRequests == []:
680
+ return torch.tensor([0.]).cuda(), context #does this make sense?
681
+ else:
682
+ argRequest = argumentRequests[0].apply(context)
683
+ laterRequests = argumentRequests[1:]
684
+
685
+ sk_firstSketch = sk_arguments[0]
686
+ full_firstSketch = full_arguments[0]
687
+ sk_laterSketches = sk_arguments[1:]
688
+ full_laterSketches = full_arguments[1:]
689
+
690
+ argL, newContext = self.sketchLogLikelihood(argRequest, full_firstSketch, sk_firstSketch, context=context, environment=environment)
691
+
692
+ #finish this...
693
+ sk_newFunction = Application(sk_function, sk_firstSketch) # is this redundant? maybe
694
+ full_newFunction = Application(full_function, full_firstSketch)
695
+
696
+ resultL, context = self.sketchllApplication(newContext, environment, sk_newFunction, sk_laterSketches,
697
+ full_newFunction, full_laterSketches, laterRequests)
698
+
699
+ return resultL + argL, context
700
+
701
+
702
+ def enumerateNearby(self, request, expr, distance=3.0):
703
+ """Enumerate programs with local mutations in subtrees with small description length"""
704
+ if distance <= 0:
705
+ yield expr
706
+ else:
707
+ def mutations(tp, loss):
708
+ for l, _, expr in self.enumeration(
709
+ Context.EMPTY, [], tp, distance - loss):
710
+ yield expr, l
711
+ yield from Mutator(self, mutations).execute(expr, request)
712
+
713
+
714
+ def enumerateHoles(self, request, expr, k=3, return_obj=Hole):
715
+ """Enumerate programs with a single hole within mdl distance"""
716
+ #TODO: make it possible to enumerate sketches with multiple holes
717
+ def mutations(tp, loss, is_left_application=False):
718
+ """
719
+ to allow applications lhs to become a hole,
720
+ remove the condition below and ignore all the is_left_application kwds
721
+ """
722
+ if not is_left_application:
723
+ yield return_obj(), 0
724
+ top_k = []
725
+ for expr, l in Mutator(self, mutations).execute(expr, request):
726
+ if len(top_k) > 0:
727
+ i, v = min(enumerate(top_k), key=lambda x:x[1][1])
728
+ if l > v[1]:
729
+ if len(top_k) >= k:
730
+ top_k[i] = (expr, l)
731
+ else:
732
+ top_k.append((expr, l))
733
+ elif len(top_k) < k:
734
+ top_k.append((expr, l))
735
+ else:
736
+ top_k.append((expr, l))
737
+ return sorted(top_k, key=lambda x:-x[1])
738
+
739
+ def untorch(self):
740
+ return Grammar(self.logVariable.data.tolist()[0],
741
+ [ (l.data.tolist()[0], t, p)
742
+ for l, t, p in self.productions],
743
+ continuationType=self.continuationType)
744
+
745
+ class LikelihoodSummary(object):
746
+ '''Summarizes the terms that will be used in a likelihood calculation'''
747
+
748
+ def __init__(self):
749
+ self.uses = {}
750
+ self.normalizers = {}
751
+ self.constant = 0.
752
+
753
+ def __str__(self):
754
+ return """LikelihoodSummary(constant = %f,
755
+ uses = {%s},
756
+ normalizers = {%s})""" % (self.constant,
757
+ ", ".join(
758
+ "%s: %d" % (k,
759
+ v) for k,
760
+ v in self.uses.items()),
761
+ ", ".join(
762
+ "%s: %d" % (k,
763
+ v) for k,
764
+ v in self.normalizers.items()))
765
+
766
+ def record(self, actual, possibles, constant=0.):
767
+ # Variables are all normalized to be $0
768
+ if isinstance(actual, Index):
769
+ actual = Index(0)
770
+
771
+ # Make it something that we can hash
772
+ possibles = frozenset(sorted(possibles, key=hash))
773
+
774
+ self.constant += constant
775
+ self.uses[actual] = self.uses.get(actual, 0) + 1
776
+ self.normalizers[possibles] = self.normalizers.get(possibles, 0) + 1
777
+
778
+ def join(self, other):
779
+ self.constant += other.constant
780
+ for k, v in other.uses.items():
781
+ self.uses[k] = self.uses.get(k, 0) + v
782
+ for k, v in other.normalizers.items():
783
+ self.normalizers[k] = self.normalizers.get(k, 0) + v
784
+
785
+ def logLikelihood(self, grammar):
786
+ return self.constant + \
787
+ sum(count * grammar.expression2likelihood[p] for p, count in self.uses.items()) - \
788
+ sum(count * lse([grammar.expression2likelihood[p] for p in ps])
789
+ for ps, count in self.normalizers.items())
790
+ def logLikelihood_overlyGeneral(self, grammar):
791
+ """Calculates log likelihood of this summary, given that the summary might refer to productions that don't occur in the grammar"""
792
+ return self.constant + \
793
+ sum(count * grammar.expression2likelihood[p] for p, count in self.uses.items()) - \
794
+ sum(count * lse([grammar.expression2likelihood.get(p,NEGATIVEINFINITY) for p in ps])
795
+ for ps, count in self.normalizers.items())
796
+ def numerator(self, grammar):
797
+ return self.constant + \
798
+ sum(count * grammar.expression2likelihood[p] for p, count in self.uses.items())
799
+ def denominator(self, grammar):
800
+ return \
801
+ sum(count * lse([grammar.expression2likelihood[p] for p in ps])
802
+ for ps, count in self.normalizers.items())
803
+ def toUses(self):
804
+ from collections import Counter
805
+
806
+ possibleVariables = sum( count if Index(0) in ps else 0
807
+ for ps, count in self.normalizers.items() )
808
+ actualVariables = self.uses.get(Index(0), 0.)
809
+ actualUses = {k: v
810
+ for k, v in self.uses.items()
811
+ if not k.isIndex }
812
+ possibleUses = dict(Counter(p
813
+ for ps, count in self.normalizers.items()
814
+ for p_ in ps
815
+ if not p_.isIndex
816
+ for p in [p_]*count ))
817
+ return Uses(possibleVariables, actualVariables,
818
+ possibleUses, actualUses)
819
+
820
+
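The likelihood a summary assigns (logLikelihood above) decomposes as constant + Σ_p uses[p]·logp(p) − Σ_S normalizers[S]·lse{logp(q) : q ∈ S}. A minimal self-contained sketch of that arithmetic, with toy string keys standing in for grammar.expression2likelihood (this sketch is not part of the commit):

    import math

    def lse(xs):
        m = max(xs)
        return m + math.log(sum(math.exp(x - m) for x in xs))

    # Toy production log-probabilities, standing in for grammar.expression2likelihood.
    logp = {"+": math.log(0.5), "0": math.log(0.25), "1": math.log(0.25)}
    uses = {"+": 1, "1": 2}             # productions actually chosen
    normalizers = {frozenset(logp): 3}  # the candidate set at each of the 3 choice points

    ll = sum(c * logp[p] for p, c in uses.items()) \
         - sum(c * lse([logp[q] for q in S]) for S, c in normalizers.items())
    # Each candidate set here is already normalized, so every lse term is 0 and
    # ll == math.log(0.5) + 2 * math.log(0.25).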
821
+ class Uses(object):
822
+ '''Tracks uses of different grammar productions'''
823
+
824
+ def __init__(self, possibleVariables=0., actualVariables=0.,
825
+ possibleUses={}, actualUses={}):
826
+ self.actualVariables = actualVariables
827
+ self.possibleVariables = possibleVariables
828
+ self.possibleUses = possibleUses
829
+ self.actualUses = actualUses
830
+
831
+ def __str__(self):
832
+ return "Uses(actualVariables = %f, possibleVariables = %f, actualUses = %s, possibleUses = %s)" %\
833
+ (self.actualVariables, self.possibleVariables, self.actualUses, self.possibleUses)
834
+
835
+ def __repr__(self): return str(self)
836
+
837
+ def __mul__(self, a):
838
+ return Uses(a * self.possibleVariables,
839
+ a * self.actualVariables,
840
+ {p: a * u for p, u in self.possibleUses.items()},
841
+ {p: a * u for p, u in self.actualUses.items()})
842
+
843
+ def __imul__(self, a):
844
+ self.possibleVariables *= a
845
+ self.actualVariables *= a
846
+ for p in self.possibleUses:
847
+ self.possibleUses[p] *= a
848
+ for p in self.actualUses:
849
+ self.actualUses[p] *= a
850
+ return self
851
+
852
+ def __rmul__(self, a):
853
+ return self * a
854
+
855
+ def __radd__(self, o):
856
+ if o == 0:
857
+ return self
858
+ return self + o
859
+
860
+ def __add__(self, o):
861
+ if o == 0:
862
+ return self
863
+
864
+ def merge(x, y):
865
+ z = x.copy()
866
+ for k, v in y.items():
867
+ z[k] = v + x.get(k, 0.)
868
+ return z
869
+ return Uses(self.possibleVariables + o.possibleVariables,
870
+ self.actualVariables + o.actualVariables,
871
+ merge(self.possibleUses, o.possibleUses),
872
+ merge(self.actualUses, o.actualUses))
873
+
874
+ def __iadd__(self, o):
875
+ self.possibleVariables += o.possibleVariables
876
+ self.actualVariables += o.actualVariables
877
+ for k, v in o.possibleUses.items():
878
+ self.possibleUses[k] = self.possibleUses.get(k, 0.) + v
879
+ for k, v in o.actualUses.items():
880
+ self.actualUses[k] = self.actualUses.get(k, 0.) + v
881
+ return self
882
+
883
+ @staticmethod
884
+ def join(z, *weightedUses):
885
+ """Consumes weightedUses"""
886
+ if not weightedUses:
887
+ return Uses.empty
888
+ if len(weightedUses) == 1:
889
+ return weightedUses[0][1]
890
+ for w, u in weightedUses:
891
+ u *= exp(w - z)
892
+ total = Uses()
893
+ total.possibleVariables = sum(
894
+ u.possibleVariables for _, u in weightedUses)
895
+ total.actualVariables = sum(u.actualVariables for _, u in weightedUses)
896
+ total.possibleUses = defaultdict(float)
897
+ total.actualUses = defaultdict(float)
898
+ for _, u in weightedUses:
899
+ for k, v in u.possibleUses.items():
900
+ total.possibleUses[k] += v
901
+ for k, v in u.actualUses.items():
902
+ total.actualUses[k] += v
903
+ return total
904
+
905
+
906
+ Uses.empty = Uses()
907
+
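Since Uses supports scaling and addition, the weighted mixtures computed by Uses.join above can also be written directly with operators. A small illustrative sketch (string keys stand in for grammar productions here):

    u1 = Uses(possibleVariables=2., actualVariables=1.,
              possibleUses={"+": 2.}, actualUses={"+": 1.})
    u2 = Uses(possibleUses={"0": 1.}, actualUses={"0": 1.})

    mixed = 0.25 * u1 + 0.75 * u2  # __rmul__ and __add__ merge the count tables
    total = sum([u1, u2])          # __radd__ makes sum() over Uses work too
    print(mixed, total)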
908
+ class ContextualGrammar:
909
+ def __init__(self, noParent, variableParent, library):
910
+ self.noParent, self.variableParent, self.library = noParent, variableParent, library
911
+
912
+ self.productions = [(None,t,p) for _,t,p in self.noParent.productions ]
913
+ self.primitives = [p for _,_2,p in self.productions ]
914
+
915
+ self.continuationType = noParent.continuationType
916
+ assert variableParent.continuationType == self.continuationType
917
+
918
+ assert set(noParent.primitives) == set(variableParent.primitives)
919
+ assert set(variableParent.primitives) == set(library.keys())
920
+ for e,gs in library.items():
921
+ assert len(gs) == len(e.infer().functionArguments())
922
+ for g in gs:
923
+ assert set(g.primitives) == set(library.keys())
924
+ assert g.continuationType == self.continuationType
925
+
926
+ def untorch(self):
927
+ return ContextualGrammar(self.noParent.untorch(), self.variableParent.untorch(),
928
+ {e: [g.untorch() for g in gs ]
929
+ for e,gs in self.library.items() })
930
+
931
+ def randomWeights(self, r):
932
+ """returns a new grammar with random weights drawn from r. calls `r` w/ old weight"""
933
+ return ContextualGrammar(self.noParent.randomWeights(r),
934
+ self.variableParent.randomWeights(r),
935
+ {e: [g.randomWeights(r) for g in gs]
936
+ for e,gs in self.library.items() })
937
+ def __str__(self):
938
+ lines = ["No parent:",str(self.noParent),"",
939
+ "Variable parent:",str(self.variableParent),"",
940
+ ""]
941
+ for e,gs in self.library.items():
942
+ for j,g in enumerate(gs):
943
+ lines.extend(["Parent %s, argument index %s"%(e,j),
944
+ str(g),
945
+ ""])
946
+ return "\n".join(lines)
947
+
948
+ def json(self):
949
+ return {"noParent": self.noParent.json(),
950
+ "variableParent": self.variableParent.json(),
951
+ "productions": [{"program": str(e),
952
+ "arguments": [gp.json() for gp in gs ]}
953
+ for e,gs in self.library.items() ]}
954
+
955
+ @staticmethod
956
+ def fromGrammar(g):
957
+ return ContextualGrammar(g, g,
958
+ {e: [g]*len(e.infer().functionArguments())
959
+ for e in g.primitives })
960
+
961
+
962
+ class LS: # likelihood summary
963
+ def __init__(self, owner):
964
+ self.noParent = LikelihoodSummary()
965
+ self.variableParent = LikelihoodSummary()
966
+ self.library = {e: [LikelihoodSummary() for _ in gs] for e,gs in owner.library.items() }
967
+
968
+ def record(self, parent, parentIndex, actual, possibles, constant):
969
+ if parent is None: ls = self.noParent
970
+ elif parent.isIndex: ls = self.variableParent
971
+ else: ls = self.library[parent][parentIndex]
972
+ ls.record(actual, possibles, constant=constant)
973
+
974
+ def join(self, other):
975
+ self.noParent.join(other.noParent)
976
+ self.variableParent.join(other.variableParent)
977
+ for e,gs in self.library.items():
978
+ for g1,g2 in zip(gs, other.library[e]):
979
+ g1.join(g2)
980
+
981
+ def logLikelihood(self, owner):
982
+ return self.noParent.logLikelihood(owner.noParent) + \
983
+ self.variableParent.logLikelihood(owner.variableParent) + \
984
+ sum(r.logLikelihood(g)
985
+ for e, rs in self.library.items()
986
+ for r,g in zip(rs, owner.library[e]) )
987
+ def numerator(self, owner):
988
+ return self.noParent.numerator(owner.noParent) + \
989
+ self.variableParent.numerator(owner.variableParent) + \
990
+ sum(r.numerator(g)
991
+ for e, rs in self.library.items()
992
+ for r,g in zip(rs, owner.library[e]) )
993
+ def denominator(self, owner):
994
+ return self.noParent.denominator(owner.noParent) + \
995
+ self.variableParent.denominator(owner.variableParent) + \
996
+ sum(r.denominator(g)
997
+ for e, rs in self.library.items()
998
+ for r,g in zip(rs, owner.library[e]) )
999
+
1000
+ def likelihoodSummary(self, parent, parentIndex, context, environment, request, expression):
1001
+ if request.isArrow():
1002
+ assert expression.isAbstraction
1003
+ return self.likelihoodSummary(parent, parentIndex,
1004
+ context,
1005
+ [request.arguments[0]] + environment,
1006
+ request.arguments[1],
1007
+ expression.body)
1008
+ if parent is None: g = self.noParent
1009
+ elif parent.isIndex: g = self.variableParent
1010
+ else: g = self.library[parent][parentIndex]
1011
+ candidates = g.buildCandidates(request, context, environment,
1012
+ normalize=False, returnTable=True)
1013
+
1014
+ # A list of everything that would have been possible to use here
1015
+ possibles = [p for p in candidates.keys() if not p.isIndex]
1016
+ numberOfVariables = sum(p.isIndex for p in candidates.keys())
1017
+ if numberOfVariables > 0:
1018
+ possibles += [Index(0)]
1019
+
1020
+ f, xs = expression.applicationParse()
1021
+
1022
+ assert f in candidates
1023
+
1024
+ thisSummary = self.LS(self)
1025
+ thisSummary.record(parent, parentIndex,
1026
+ f, possibles,
1027
+ constant= -math.log(numberOfVariables) if f.isIndex else 0)
1028
+
1029
+ _, tp, context = candidates[f]
1030
+ argumentTypes = tp.functionArguments()
1031
+ assert len(xs) == len(argumentTypes)
1032
+
1033
+ for i, (argumentType, argument) in enumerate(zip(argumentTypes, xs)):
1034
+ argumentType = argumentType.apply(context)
1035
+ context, newSummary = self.likelihoodSummary(f, i,
1036
+ context, environment, argumentType, argument)
1037
+ thisSummary.join(newSummary)
1038
+
1039
+ return context, thisSummary
1040
+
1041
+ def closedLikelihoodSummary(self, request, expression):
1042
+ return self.likelihoodSummary(None,None,
1043
+ Context.EMPTY,[],
1044
+ request, expression)[1]
1045
+
1046
+ def logLikelihood(self, request, expression):
1047
+ return self.closedLikelihoodSummary(request, expression).logLikelihood(self)
1048
+
1049
+ def sample(self, request, maximumDepth=8, maxAttempts=None):
1050
+ attempts = 0
1051
+ while True:
1052
+ try:
1053
+ _, e = self._sample(None, None, Context.EMPTY, [], request, maximumDepth)
1054
+ return e
1055
+ except NoCandidates:
1056
+ if maxAttempts is not None:
1057
+ attempts += 1
1058
+ if attempts > maxAttempts: return None
1059
+ continue
1060
+
1061
+ def _sample(self, parent, parentIndex, context, environment, request, maximumDepth):
1062
+ if request.isArrow():
1063
+ context, body = self._sample(parent, parentIndex, context,
1064
+ [request.arguments[0]] + environment,
1065
+ request.arguments[1],
1066
+ maximumDepth)
1067
+ return context, Abstraction(body)
1068
+ if parent is None: g = self.noParent
1069
+ elif parent.isIndex: g = self.variableParent
1070
+ else: g = self.library[parent][parentIndex]
1071
+ candidates = g.buildCandidates(request, context, environment,
1072
+ normalize=True, returnProbabilities=True,
1073
+ mustBeLeaf=(maximumDepth <= 1))
1074
+ newType, chosenPrimitive, context = sampleDistribution(candidates)
1075
+
1076
+ xs = newType.functionArguments()
1077
+ returnValue = chosenPrimitive
1078
+
1079
+ for j,x in enumerate(xs):
1080
+ x = x.apply(context)
1081
+ context, x = self._sample(chosenPrimitive, j, context, environment, x, maximumDepth - 1)
1082
+ returnValue = Application(returnValue, x)
1083
+
1084
+ return context, returnValue
1085
+
1086
+ def expectedUsesMonteCarlo(self, request, debug=None):
1087
+ import numpy as np
1088
+ n = 0
1089
+ u = [0.]*len(self.primitives)
1090
+ primitives = list(sorted(self.primitives, key=str))
1091
+ noInventions = all( not p.isInvented for p in primitives )
1092
+ primitive2index = {primitive: i
1093
+ for i, primitive in enumerate(primitives)
1094
+ if primitive.isInvented or noInventions }
1095
+ eprint(primitive2index)
1096
+ ns = 10000
1097
+ with timing(f"calculated expected uses using Monte Carlo simulation w/ {ns} samples"):
1098
+ for _ in range(ns):
1099
+ p = self.sample(request, maxAttempts=0)
1100
+ if p is None: continue
1101
+ n += 1
1102
+ if debug and n < 10:
1103
+ eprint(debug, p)
1104
+ for _, child in p.walk():
1105
+ if child not in primitive2index: continue
1106
+ u[primitive2index[child]] += 1.0
1107
+ u = np.array(u)/n
1108
+ if debug:
1109
+ eprint(f"Got {n} samples. Feature vector:\n{u}")
1110
+ eprint(f"Likely used primitives: {[p for p,i in primitive2index.items() if u[i] > 0.5]}")
1111
+ eprint(f"Likely used primitive indices: {[i for p,i in primitive2index.items() if u[i] > 0.5]}")
1112
+ return u
1113
+
1114
+ def featureVector(self, _=None, requests=None, onlyInventions=True, normalize=True):
1115
+ """
1116
+ Returns the probabilities licensed by the type system.
1117
+ This is like the grammar productions, but with irrelevant junk removed.
1118
+ Its intended use case is for clustering; it should be strictly better than the raw transition matrix.
1119
+ """
1120
+ if requests is None:
1121
+ if self.continuationType: requests = {self.continuationType}
1122
+ elif any( 'REAL' == str(p) for p in self.primitives ): requests = set()
1123
+ elif any( 'STRING' == str(p) for p in self.primitives ): requests = {tlist(tcharacter)}
1124
+ else: requests = set()
1125
+ requests = {r.returns() for r in requests}
1126
+ features = []
1127
+ logWeights = []
1128
+ for l,t,p in sorted(self.noParent.productions,
1129
+ key=lambda z: str(z[2])):
1130
+ if onlyInventions and not p.isInvented: continue
1131
+ if any( canUnify(r, t.returns()) for r in requests ) or len(requests) == 0:
1132
+ logWeights.append(l)
1133
+ features.append(logWeights)
1134
+ for parent in sorted(self.primitives, key=str):
1135
+ if onlyInventions and not parent.isInvented: continue
1136
+ if parent not in self.library: continue
1137
+ argumentTypes = parent.infer().functionArguments()
1138
+ for j,g in enumerate(self.library[parent]):
1139
+ argumentType = argumentTypes[j]
1140
+ logWeights = []
1141
+ for l,t,p in sorted(g.productions,
1142
+ key=lambda z: str(z[2])):
1143
+ if onlyInventions and not p.isInvented: continue
1144
+ if canUnify(argumentType.returns(), t.returns()):
1145
+ logWeights.append(l)
1146
+ features.append(logWeights)
1147
+
1148
+ if normalize:
1149
+ features = [ [math.exp(w - z) for w in lw ]
1150
+ for lw in features
1151
+ if lw
1152
+ for z in [lse(lw)] ]
1153
+ import numpy as np
1154
+ return np.array([f
1155
+ for lw in features
1156
+ for f in lw])
1157
+
1158
+ def enumeration(self,context,environment,request,upperBound,
1159
+ parent=None, parentIndex=None,
1160
+ maximumDepth=20,
1161
+ lowerBound=0.):
1162
+ '''Enumerates all programs whose MDL satisfies: lowerBound <= MDL < upperBound'''
1163
+ if upperBound < 0 or maximumDepth == 1:
1164
+ return
1165
+
1166
+ if request.isArrow():
1167
+ v = request.arguments[0]
1168
+ for l, newContext, b in self.enumeration(context, [v] + environment,
1169
+ request.arguments[1],
1170
+ parent=parent, parentIndex=parentIndex,
1171
+ upperBound=upperBound,
1172
+ lowerBound=lowerBound,
1173
+ maximumDepth=maximumDepth):
1174
+ yield l, newContext, Abstraction(b)
1175
+ else:
1176
+ if parent is None: g = self.noParent
1177
+ elif parent.isIndex: g = self.variableParent
1178
+ else: g = self.library[parent][parentIndex]
1179
+
1180
+ candidates = g.buildCandidates(request, context, environment,
1181
+ normalize=True)
1182
+
1183
+ for l, t, p, newContext in candidates:
1184
+ mdl = -l
1185
+ if not (mdl < upperBound):
1186
+ continue
1187
+
1188
+ xs = t.functionArguments()
1189
+ for aL, aK, application in\
1190
+ self.enumerateApplication(newContext, environment, p, xs,
1191
+ parent=p,
1192
+ upperBound=upperBound + l,
1193
+ lowerBound=lowerBound + l,
1194
+ maximumDepth=maximumDepth - 1):
1195
+ yield aL + l, aK, application
1196
+
1197
+ def enumerateApplication(self, context, environment,
1198
+ function, argumentRequests,
1199
+ # Upper bound on the description length of all of
1200
+ # the arguments
1201
+ upperBound,
1202
+ # Lower bound on the description length of all of
1203
+ # the arguments
1204
+ lowerBound=0.,
1205
+ maximumDepth=20,
1206
+ parent=None,
1207
+ originalFunction=None,
1208
+ argumentIndex=0):
1209
+ assert parent is not None
1210
+ if upperBound < 0. or maximumDepth == 1:
1211
+ return
1212
+ if originalFunction is None:
1213
+ originalFunction = function
1214
+
1215
+ if argumentRequests == []:
1216
+ if lowerBound <= 0. and 0. < upperBound:
1217
+ yield 0., context, function
1218
+ else:
1219
+ return
1220
+ else:
1221
+ argRequest = argumentRequests[0].apply(context)
1222
+ laterRequests = argumentRequests[1:]
1223
+ for argL, newContext, arg in self.enumeration(context, environment, argRequest,
1224
+ parent=parent, parentIndex=argumentIndex,
1225
+ upperBound=upperBound,
1226
+ lowerBound=0.,
1227
+ maximumDepth=maximumDepth):
1228
+ if violatesSymmetry(originalFunction, arg, argumentIndex):
1229
+ continue
1230
+
1231
+ newFunction = Application(function, arg)
1232
+ for resultL, resultK, result in self.enumerateApplication(newContext, environment, newFunction,
1233
+ laterRequests,
1234
+ parent=parent,
1235
+ upperBound=upperBound + argL,
1236
+ lowerBound=lowerBound + argL,
1237
+ maximumDepth=maximumDepth,
1238
+ originalFunction=originalFunction,
1239
+ argumentIndex=argumentIndex + 1):
1240
+ yield resultL + argL, resultK, result
1241
+
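Every ContextualGrammar method above (likelihood summaries, sampling, enumeration) repeats the same three-way dispatch on the parent of the current choice point. Spelled out once as a hypothetical helper (not in the commit):

    def transition(cg, parent, parentIndex):
        # Pick the sub-grammar governing the next choice point of a ContextualGrammar.
        if parent is None:      # at the root of the program
            return cg.noParent
        if parent.isIndex:      # filling the argument of a bound variable
            return cg.variableParent
        return cg.library[parent][parentIndex]  # argument `parentIndex` of a primitive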
1242
+
1243
+
1244
+
1245
+ def violatesSymmetry(f, x, argumentIndex):
1246
+ if not f.isPrimitive:
1247
+ return False
1248
+ while x.isApplication:
1249
+ x = x.f
1250
+ if not x.isPrimitive:
1251
+ return False
1252
+ f = f.name
1253
+ x = x.name
1254
+ if f == "car":
1255
+ return x == "cons" or x == "empty"
1256
+ if f == "cdr":
1257
+ return x == "cons" or x == "empty"
1258
+ if f == "+":
1259
+ return x == "0" or (argumentIndex == 1 and x == "+")
1260
+ if f == "-":
1261
+ return argumentIndex == 1 and x == "0"
1262
+ if f == "empty?":
1263
+ return x == "cons" or x == "empty"
1264
+ if f == "zero?":
1265
+ return x == "0" or x == "1"
1266
+ if f == "index" or f == "map" or f == "zip":
1267
+ return x == "empty"
1268
+ if f == "range":
1269
+ return x == "0"
1270
+ if f == "fold":
1271
+ return argumentIndex == 1 and x == "empty"
1272
+ return False
1273
+
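A quick illustration of these pruning rules, assuming the list primitives (car, cons) have been registered in Primitive.GLOBALS:

    car, cons = Primitive.GLOBALS["car"], Primitive.GLOBALS["cons"]
    redundant = Application(Application(cons, Index(0)), Index(1))  # (cons $0 $1)
    assert violatesSymmetry(car, redundant, 0)  # (car (cons ...)) is never enumerated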
1274
+ def batchLikelihood(jobs):
1275
+ """Takes as input a set of (program, request, grammar) and returns a dictionary mapping each of these to its likelihood under the grammar"""
1276
+ superGrammar = Grammar.uniform(list({p for _1,_2,g in jobs for p in g.primitives}),
1277
+ continuationType=list(jobs)[0][-1].continuationType)
1278
+ programsAndRequests = {(program, request)
1279
+ for program, request, grammar in jobs}
1280
+ with timing(f"Calculated {len(programsAndRequests)} likelihood summaries"):
1281
+ summary = {(program, request): superGrammar.closedLikelihoodSummary(request, program)
1282
+ for program, request in programsAndRequests}
1283
+ with timing(f"Calculated log likelihoods from summaries"):
1284
+ response = {}
1285
+ for program, request, grammar in jobs:
1286
+ fast = summary[(program, request)].logLikelihood_overlyGeneral(grammar)
1287
+ if False: # debugging
1288
+ slow = grammar.logLikelihood(request, program)
1289
+ print(program)
1290
+ eprint(grammar.closedLikelihoodSummary(request, program))
1291
+ eprint(superGrammar.closedLikelihoodSummary(request, program))
1292
+ print()
1293
+ assert abs(fast - slow) < 0.0001
1294
+ response[(program, request, grammar)] = fast
1295
+ return response
1296
+
1297
+ if __name__ == "__main__":
1298
+ from dreamcoder.domains.arithmetic.arithmeticPrimitives import *
1299
+ g = ContextualGrammar.fromGrammar(Grammar.uniform([k0,k1,addition, subtraction]))
1300
+ g = g.randomWeights(lambda *a: random.random())
1301
+ #p = Program.parse("(lambda (+ 1 $0))")
1302
+ request = arrow(tint,tint)
1303
+ for ll,_,p in g.enumeration(Context.EMPTY,[],request,
1304
+ 12.):
1305
+ ll_ = g.logLikelihood(request,p)
1306
+ print(ll,p,ll_)
1307
+ d = abs(ll - ll_)
1308
+ assert d < 0.0001
dreamcoder/likelihoodModel.py ADDED
@@ -0,0 +1,407 @@
1
+ from dreamcoder.task import Task, EvaluationTimeout
2
+ import gc
+ import random # used by FeatureDiscriminatorLikelihoodModel below
+ import signal # used for evaluation timeouts in ProbabilisticLikelihoodModel.score
3
+ from dreamcoder.utilities import *
4
+ from collections import Counter
5
+ import math
6
+
7
+ from dreamcoder.domains.regex.groundtruthRegexes import gt_dict
8
+
9
+ gt_dict = {"Data column no. "+str(num): r_str for num, r_str in gt_dict.items()}
10
+
11
+ class AllOrNothingLikelihoodModel:
12
+ def __init__(self, timeout=None):
13
+ self.timeout = timeout
14
+
15
+ def score(self, program, task):
16
+ logLikelihood = task.logLikelihood(program, self.timeout)
17
+ return valid(logLikelihood), logLikelihood
18
+
19
+
20
+ class EuclideanLikelihoodModel:
21
+ """Likelihood is based on Euclidean distance between features"""
22
+
23
+ def __init__(self, featureExtractor, successCutoff=0.9):
24
+ self.extract = featureExtractor
25
+ self.successCutoff = successCutoff
26
+
27
+ def score(self, program, task):
28
+ taskFeat = self.extract.featuresOfTask(task)
29
+ progFeat = self.extract.featuresOfProgram(program, task.request)
30
+ assert len(taskFeat) == len(progFeat)
31
+ distance = sum((x1 - x2)**2 for x1, x2 in zip(taskFeat, progFeat))
32
+ logLikelihood = float(-distance) # FIXME: this is really naive
33
+ return exp(logLikelihood) > self.successCutoff, logLikelihood
34
+
35
+ def longest_common_substr(arr):
36
+ """Find the stem (longest common substring) shared by a list of example strings."""
42
+ # Determine size of the array
43
+ n = len(arr)
44
+
45
+ # Take first word from array
46
+ # as reference
47
+ s = arr[0]
48
+ l = len(s)
49
+ res = ""
50
+ for i in range(l) :
51
+ for j in range( i + 1, l + 1) :
52
+ # generating all possible substrings
53
+ # of our reference string arr[0] i.e s
54
+ stem = s[i:j]
55
+ k = 1
56
+ for k in range(1, n):
57
+
58
+ # Check if the generated stem is
59
+ # common to all words
60
+ if stem not in arr[k]:
61
+ break
62
+
63
+ # If current substring is present in
64
+ # all strings and its length is greater
65
+ # than current result
66
+ if (k + 1 == n and len(res) < len(stem)): res = stem
67
+ return res
68
+
69
+ def add_string_constants(tasks):
70
+ for task in tasks:
71
+ task.str_const = longest_common_substr([example[1] for example in task.examples])
72
+ return tasks
73
+
74
+ def get_gt_ll(name, examples):
75
+ #gets groundtruth from dict
76
+ import pregex as pre
77
+ r_str = gt_dict[name]
78
+ preg = pre.create(r_str)
79
+
80
+ if type(examples[0]) == list:
81
+ examples = [ "".join(example) for example in examples]
82
+
83
+ s = sum( preg.match(example) for example in examples)
84
+ if s == float("-inf"):
85
+ print("bad for ", name)
86
+ print('preg:', preg)
87
+ print('preg sample:', [preg.sample() for i in range(3)])
88
+ print("exs", examples)
89
+ #assert False
90
+ return s
91
+
92
+
93
+ def add_cutoff_values(tasks, ll_cutoff):
94
+ from dreamcoder.domains.regex.makeRegexTasks import makeNewTasks
95
+ if ll_cutoff is None or ll_cutoff == "None":
96
+ for task in tasks:
97
+ task.ll_cutoff = None
98
+ return tasks
99
+ if ll_cutoff == "gt":
100
+ from dreamcoder.domains.regex.makeRegexTasks import regexHeldOutExamples
101
+ for task in tasks:
102
+ task.ll_cutoff = None
103
+ task.gt = get_gt_ll(task.name, [example[1] for example in task.examples])
104
+ task.gt_test = get_gt_ll(task.name,
105
+ [example[1] for example in regexHeldOutExamples(task) ])
106
+ return tasks
107
+ elif ll_cutoff == "plus":
108
+ for task in tasks:
109
+ task.ll_cutoff = regex_plus_bound([example[1] for example in task.examples])
110
+ return tasks
111
+ elif ll_cutoff == "bigram":
112
+ eprint("WARNING: using entire corpus to make bigram model")
113
+ #this means i do it twice, which is eh whatever
114
+ model = make_corpus_bigram(show_tasks(makeNewTasks()))
115
+ for task in tasks:
116
+ task.ll_cutoff = bigram_corpus_score([example[1] for example in task.examples], model)
117
+ return tasks
118
+ elif ll_cutoff =="unigram":
119
+ eprint("WARNING: using entire corpus to make unigram model")
120
+ #this means i do it twice, which is eh whatever
121
+ model = make_corpus_unigram(show_tasks(makeNewTasks()))
122
+ for task in tasks:
123
+ task.ll_cutoff = unigram_corpus_score([example[1] for example in task.examples], model)
124
+ return tasks
125
+ elif ll_cutoff =="mix":
126
+ eprint("WARNING: using entire corpus to make bigram model")
127
+ eprint("WARNING: using entire corpus to make unigram model")
128
+ #this means i do it twice, which is eh whatever
129
+ unigram = make_corpus_unigram(show_tasks(makeNewTasks()))
130
+ bigram = make_corpus_bigram(show_tasks(makeNewTasks()))
131
+ for task in tasks:
132
+ uniscore = unigram_corpus_score([example[1] for example in task.examples], unigram)
133
+ biscore = bigram_corpus_score([example[1] for example in task.examples], bigram)
134
+ task.ll_cutoff = math.log(0.75*math.exp(biscore) + 0.25*math.exp(uniscore))
135
+ return tasks
136
+ else:
137
+ eprint("not implemented")
138
+ eprint("cutoff val:")
139
+ eprint(ll_cutoff)
140
+ assert False
141
+
142
+ def show_tasks(dataset):
143
+ task_list = []
144
+ for task in dataset:
145
+ task_list.append([example[1] for example in task.examples])
146
+ return task_list
147
+
148
+ def regex_plus_bound(X):
149
+ from pregex import pregex
150
+ c = Counter(X)
151
+ regexes = [
152
+ pregex.create(".+"),
153
+ pregex.create("\d+"),
154
+ pregex.create("\w+"),
155
+ pregex.create("\s+"),
156
+ pregex.create("\\u+"),
157
+ pregex.create("\l+")]
158
+ regex_scores = []
159
+ for r in regexes:
160
+ regex_scores.append(sum(c[x] * r.match(x) for x in c)/float(sum([len(x) for x in X])) )
161
+ return max(regex_scores)
162
+
163
+
164
+ def make_corpus_unigram(C):
165
+ str_list = [example + '\n' for task in C for example in task]
166
+ c = Counter(char for example in str_list for char in example )
167
+ n = sum(c.values())
168
+
169
+ logp = {x:math.log(c[x]/n) for x in c}
170
+ return logp
171
+
172
+ def unigram_corpus_score(X, logp):
173
+ task_ll = 0
174
+ for x in X:
175
+ x = x + '\n'
176
+ task_ll += sum( logp.get(c, float('-inf')) for c in x)/len(x)
177
+
178
+ ll = task_ll/len(X)
179
+ return ll
180
+
181
+ def unigram_task_score(X):
182
+ """
183
+ Given a list of strings, X, calculate the maximum log-likelihood per character for a unigram model over characters (including STOP symbol)
184
+ """
185
+ c = Counter(x for s in X for x in s)
186
+ c.update("end" for s in X)
187
+ n = sum(c.values())
188
+ logp = {x:math.log(c[x]/n) for x in c}
189
+ return sum(c[x]*logp[x] for x in c)/n
190
+
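unigram_task_score gives the tightest per-character log-likelihood that any character unigram model can achieve on a task's own strings. For example:

    # counts are {a: 3, b: 1, 'end': 2} with n = 6; the result is the negative
    # entropy of that empirical distribution.
    print(unigram_task_score(["ab", "aa"]))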
191
+ def make_corpus_bigram(C):
192
+ #using newline as "end"
193
+ #C is a list of tasks
194
+
195
+ #make one big list of strings
196
+ str_list = [example + '\n' for task in C for example in task]
197
+
198
+ #make list of
199
+ head_count = Counter(element[0] for element in str_list)
200
+ head_n = sum(head_count.values())
201
+ head_logp = {x:math.log(head_count[x]/head_n) for x in head_count}
202
+
203
+ body_count = Counter(element[i:i+2] for element in str_list for i in range(len(element)-1))
204
+ body_bigram_n = sum(body_count.values())
205
+ #body_count/body_bigram_n gives the joint of a bigram
206
+ body_character_n = Counter(char for element in str_list for char in element)
207
+ body_unigram_n = sum(body_character_n.values())
208
+
209
+ body_logp = {x:math.log(body_count[x] / body_bigram_n / body_character_n[x[0]] * body_unigram_n) for x in body_count}
210
+
211
+ return {**head_logp, **body_logp}
212
+
213
+ def bigram_corpus_score(X, logp):
214
+ #assume you have a logp dict
215
+ task_ll = 0
216
+ for x in X:
217
+ bigram_list = [x[0]] + [x[i:i+2] for i in range(len(x)-1)] + [x[-1] + '\n']
218
+ bigram_list = [ ''.join(b) if isinstance(b,list) else b
219
+ for b in bigram_list ]
220
+
221
+ string_ll = sum(logp.get(bigram, float('-inf')) for bigram in bigram_list) #/(len(x) + 1)
222
+
223
+ task_ll += string_ll
224
+
225
+ ll = task_ll #/len(X)
226
+ return ll
227
+
228
+
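A hedged usage sketch of the corpus cutoff models above (toy corpus; each task contributes its list of example strings):

    corpus = [["cat", "car"], ["cot", "cut"]]
    unigram = make_corpus_unigram(corpus)
    bigram = make_corpus_bigram(corpus)
    print(unigram_corpus_score(["cat"], unigram))  # mean per-character log-likelihood
    print(bigram_corpus_score(["cat"], bigram))    # summed bigram log-likelihood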
229
+ class ProbabilisticLikelihoodModel:
230
+
231
+ def __init__(self, timeout):
232
+ self.timeout = timeout
233
+ # i need timeout
234
+
235
+ def score(self, program, task):
236
+ # need a try, catch here for problems, and for timeouts
237
+ # can copy task.py for the timeout structure
238
+ try:
239
+ def timeoutCallBack(_1, _2): raise EvaluationTimeout()
240
+ signal.signal(signal.SIGVTALRM, timeoutCallBack)
241
+ signal.setitimer(signal.ITIMER_VIRTUAL, self.timeout)
242
+ try:
243
+ string_pregex = program.evaluate([])
244
+ # if 'left_paren' in program.show(False):
245
+ #eprint("string_pregex:", string_pregex)
246
+ #eprint("string_pregex:", string_pregex)
247
+ preg = string_pregex # pregex.create(string_pregex)
248
+ except IndexError:
249
+ # free variable
250
+ return False, NEGATIVEINFINITY
251
+ except Exception as e:
252
+ eprint("Exception during evaluation:", e)
253
+ if "Attempt to evaluate fragment variable" in e:
254
+ eprint("program (bc fragment error)", program)
255
+ return False, NEGATIVEINFINITY
256
+
257
+ #tries and catches
258
+
259
+ # include prior somehow
260
+ # right now, just summing up log likelihoods. IDK if this is correct.
261
+ # also not using prior at all.
262
+
263
+ cum_ll = 0
264
+
265
+ example_list = [example[1] for example in task.examples]
266
+ c_example_list = Counter(example_list)
267
+
268
+ for c_example in c_example_list:
269
+ #might want a try, except around the following line:
270
+
271
+ try:
272
+ #eprint("about to match", program)
273
+ #print("preg:", preg)
274
+ ll = preg.match(c_example)
275
+ #eprint("completed match", ll, program)
276
+ except ValueError as e:
277
+ eprint("ValueError:", e)
278
+ ll = float('-inf')
279
+
280
+ #eprint("pregex:", string_pregex)
281
+ #eprint("example[1]", example[1])
282
+
283
+ if ll == float('-inf'):
284
+ return False, NEGATIVEINFINITY
285
+ else:
286
+ #ll_per_char = ll/float(len(example[1]))
287
+ #cum_ll_per_char += ll_per_char
288
+
289
+ cum_ll += c_example_list[c_example] * ll
290
+
291
+ #normalized_cum_ll_per_char = cum_ll_per_char/float(len(task.examples))
292
+ #avg_char_num = sum([len(example[1]) for example in task.examples])/float(len(task.examples))
293
+
294
+ #cutoff_ll = regex_plus_bound(example_list)
295
+
296
+ normalized_cum_ll = cum_ll/ float(sum([len(example) for example in example_list]))
297
+
298
+
299
+
300
+ #TODO: change the way normalized_cum_ll is calculated
301
+ #TODO: refactor to pass in bigram_model, and others
302
+ #TODO: refactor to do 95% certainty thing josh wants
303
+ success = normalized_cum_ll > task.ll_cutoff
304
+
305
+
306
+
307
+ #eprint("cutoff_ll:", cutoff_ll, ", norm_cum_ll:", normalized_cum_ll)
308
+
309
+ return success, normalized_cum_ll
310
+
311
+ except EvaluationTimeout:
312
+ eprint("Timed out while evaluating", program)
313
+ return False, NEGATIVEINFINITY
314
+ finally:
315
+ signal.signal(signal.SIGVTALRM, lambda *_: None)
316
+ signal.setitimer(signal.ITIMER_VIRTUAL, 0)
317
+
318
+
319
+ try:
320
+ import torch
321
+ import torch.nn as nn
322
+ import torch.nn.functional as F
323
+ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
324
+ from torch.autograd import Variable
325
+
326
+ class FeatureDiscriminatorLikelihoodModel(nn.Module):
327
+ def __init__(self, tasks, featureExtractor,
328
+ successCutoff=0.6, H=8, trainingSuccessRatio=0.5):
329
+ super(FeatureDiscriminatorLikelihoodModel, self).__init__()
330
+ self.extract = featureExtractor
331
+ self.successCutoff = successCutoff
332
+ self.trainingSuccessRatio = trainingSuccessRatio
333
+
334
+ self.W = nn.Linear(featureExtractor.outputDimensionality, H)
335
+ self.output = nn.Linear(H, 1)
336
+
337
+ # training on initialization
338
+ self.train(tasks)
339
+
340
+ def forward(self, examples):
341
+ """
342
+ Examples is a list of feature sets corresponding to a particular example.
343
+ Output in [0,1] whether all examples correspond to the same program
344
+ """
345
+ assert all(
346
+ len(x) == self.extract.outputDimensionality for x in examples)
347
+ examples = [F.tanh(self.W(ex)) for ex in examples]
348
+ maxed, _ = torch.max(torch.stack(examples), dim=0)
349
+ return F.sigmoid(self.output(maxed))
350
+
351
+ def train(self, tasks, steps=400):
352
+ # list of list of features for each example in each task
353
+ optimizer = torch.optim.Adam(self.parameters())
354
+ with timing("Trained discriminator"):
355
+ losses = []
356
+ for i in range(steps):
357
+ self.zero_grad()
358
+ if random.random() <= self.trainingSuccessRatio:
359
+ # success
360
+ t = random.choice(tasks)
361
+ features = [self.extract.featuresOfTask(
362
+ Task(t.name, t.request, [ex], t.features))
363
+ for ex in t.examples]
364
+ loss = (self(features) - 1.0)**2
365
+ else:
366
+ # fail
367
+ t1, t2 = random.sample(tasks, 2)
368
+ features1 = [self.extract.featuresOfTask(
369
+ Task(t1.name, t1.request, [ex], t1.features))
370
+ for ex in t1.examples[:len(t1.examples) // 2]]
371
+ features2 = [self.extract.featuresOfTask(
372
+ Task(t2.name, t2.request, [ex], t2.features))
373
+ for ex in t2.examples[len(t2.examples) // 2:]]
374
+ features = features1 + features2
375
+ loss = self(features)**2
376
+
377
+ loss.backward()
378
+ optimizer.step()
379
+ losses.append(loss.data[0])
380
+ if not i % 50:
381
+ eprint("Discriminator Epoch", i, "Loss", sum(losses) / len(losses))
387
+ gc.collect()
388
+
389
+ def score(self, program, task):
390
+ taskFeatures = self.extract.featuresOfTask(task)
391
+ progFeatures = self.extract.featuresOfProgram(
392
+ program, task.request)
393
+ likelihood = self([taskFeatures] + [progFeatures])
394
+ likelihood = float(likelihood)
395
+ return likelihood > self.successCutoff, log(likelihood)
396
+ except ImportError:
397
+ pass
398
+
399
+
400
+ if __name__=="__main__":
401
+
402
+ arr = ['MAM.OSBS.2014.06', 'MAM.OSBS.2013.07', 'MAM.OSBS.2013.09', 'MAM.OSBS.2014.05', 'MAM.OSBS.2014.11']
403
+ stems = longest_common_substr(arr)
404
+ print(stems)
405
+
406
+
407
+
dreamcoder/primitiveGraph.py ADDED
@@ -0,0 +1,182 @@
1
+ from dreamcoder.program import *
2
+
3
+ def graphPrimitives(result, prefix, view=False):
4
+ try:
5
+ from graphviz import Digraph
6
+ except:
7
+ eprint("You are missing the graphviz library - cannot graph primitives!")
8
+ return
9
+
10
+
11
+ primitives = { p
12
+ for g in result.grammars
13
+ for p in g.primitives
14
+ if p.isInvented }
15
+ age = {p: min(j for j,g in enumerate(result.grammars) if p in g.primitives) + 1
16
+ for p in primitives }
17
+
18
+
19
+
20
+ ages = set(age.values())
21
+ age2primitives = {a: {p for p,ap in age.items() if a == ap }
22
+ for a in ages}
23
+
24
+ def lb(s,T=20):
25
+ s = s.split()
26
+ l = []
27
+ n = 0
28
+ for w in s:
29
+ if n + len(w) > T:
30
+ l.append("<br />")
31
+ n = 0
32
+ n += len(w)
33
+ l.append(w)
34
+ return " ".join(l)
35
+
36
+ nameSimplification = {
37
+ "fix1": 'Y',
38
+ "tower_loopM": "for",
39
+ "tower_embed": "get/set",
40
+ "moveHand": "move",
41
+ "reverseHand": "reverse",
42
+ "logo_DIVA": '/',
43
+ "logo_epsA": 'ε',
44
+ "logo_epsL": 'ε',
45
+ "logo_IFTY": '∞',
46
+ "logo_forLoop": "for",
47
+ "logo_UA": "2π",
48
+ "logo_FWRT": "move",
49
+ "logo_UL": "1",
50
+ "logo_SUBA": "-",
51
+ "logo_ZL": "0",
52
+ "logo_ZA": "0",
53
+ "logo_MULL": "*",
54
+ "logo_MULA": "*",
55
+ "logo_PT": "pen-up",
56
+ "logo_GETSET": "get/set"
57
+ }
58
+
59
+
60
+ name = {}
61
+ simplification = {}
62
+ depth = {}
63
+ def getName(p):
64
+ if p in name: return name[p]
65
+ children = {k: getName(k)
66
+ for _,k in p.body.walk()
67
+ if k.isInvented}
68
+ simplification_ = p.body
69
+ for k,childName in children.items():
70
+ simplification_ = simplification_.substitute(k, Primitive(childName,None,None))
71
+ for original, simplified in nameSimplification.items():
72
+ simplification_ = simplification_.substitute(Primitive(original,None,None),
73
+ Primitive(simplified,None,None))
74
+ name[p] = "f%d"%len(name)
75
+ simplification[p] = name[p] + '=' + lb(prettyProgram(simplification_, Lisp=True))
76
+ depth[p] = 1 + max([depth[k] for k in children] + [0])
77
+ return name[p]
78
+
79
+ for p in primitives:
80
+ getName(p)
81
+
82
+ depths = {depth[p] for p in primitives}
83
+ depth2primitives = {d: {p for p in primitives if depth[p] == d }
84
+ for d in depths}
85
+
86
+ englishDescriptions = {"#(lambda (lambda (map (lambda (index $0 $2)) (range $0))))":
87
+ "Prefix",
88
+ "#(lambda (lambda (fold $0 $1 (lambda (lambda (cons $1 $0))))))":
89
+ "Append",
90
+ "#(lambda (cons LPAREN (#(lambda (lambda (fold $0 $1 (lambda (lambda (cons $1 $0)))))) (cons RPAREN empty) $0)))":
91
+ "Enclose w/ parens",
92
+ "#(lambda (unfold $0 (lambda (empty? $0)) (lambda (car $0)) (lambda (#(lambda (lambda (fold $1 $1 (lambda (lambda (cdr (if (char-eq? $1 $2) $3 $0))))))) $0 SPACE))))":
93
+ "Abbreviate",
94
+ "#(lambda (lambda (fold $1 $1 (lambda (lambda (cdr (if (char-eq? $1 $2) $3 $0)))))))":
95
+ "Drop until char",
96
+ "#(lambda (lambda (fold $1 $1 (lambda (lambda (if (char-eq? $1 $2) empty (cons $1 $0)))))))":
97
+ "Take until char",
98
+ "#(lambda (lambda (#(lambda (lambda (fold $0 $1 (lambda (lambda (cons $1 $0)))))) (cons $0 $1))))":
99
+ "Append char",
100
+ "#(lambda (lambda (map (lambda (if (char-eq? $0 $1) $2 $0)))))":
101
+ "Substitute char",
102
+ "#(lambda (lambda (length (unfold $1 (lambda (char-eq? (car $0) $1)) (lambda ',') (lambda (cdr $0))))))":
103
+ "Index of char",
104
+ "#(lambda (#(lambda (lambda (fold $0 $1 (lambda (lambda (cons $1 $0)))))) $0 STRING))":
105
+ "Append const",
106
+ "#(lambda (lambda (fold $1 $1 (lambda (lambda (fold $0 $0 (lambda (lambda (cdr (if (char-eq? $1 $4) $0 (cons $1 $0)))))))))))":
107
+ "Last word",
108
+ "#(lambda (lambda (cons (car $1) (cons '.' (cons (car $0) (cons '.' empty))))))":
109
+ "Abbreviate name",
110
+ "#(lambda (lambda (cons (car $1) (cons $0 empty))))":
111
+ "First char+char",
112
+ "#(lambda (#(lambda (lambda (fold $0 $1 (lambda (lambda (cons $1 $0)))))) (#(lambda (lambda (fold $1 $1 (lambda (lambda (fold $0 $0 (lambda (lambda (cdr (if (char-eq? $1 $4) $0 (cons $1 $0))))))))))) STRING (index (length (cdr $0)) $0)) $0))":
113
+ "Ensure suffix"
114
+
115
+ }
116
+
117
+ def makeUnorderedGraph(fn):
118
+ g = Digraph()
119
+ g.graph_attr['rankdir'] = 'LR'
120
+
121
+ for p in primitives:
122
+ g.node(getName(p),
123
+ label="<%s>"%simplification[p])
124
+ for p in primitives:
125
+ children = {k
126
+ for _,k in p.body.walk()
127
+ if k.isInvented}
128
+ for k in children:
129
+ g.edge(name[k],name[p])
130
+ try:
131
+ g.render(fn,view=view)
132
+ eprint("Exported primitive graph to",fn)
133
+ except:
134
+ eprint("Got some kind of error while trying to render primitive graph! Did you install graphviz/dot?")
135
+
136
+
137
+
138
+ def makeGraph(ordering, fn):
139
+ g = Digraph()
140
+ g.graph_attr['rankdir'] = 'RL'
141
+
142
+ if False:
143
+ with g.subgraph(name='cluster_0') as sg:
144
+ sg.graph_attr['rank'] = 'same'
145
+ sg.attr(label='Primitives')
146
+ for j, primitive in enumerate(result.grammars[-1].primitives):
147
+ if primitive.isInvented: continue
148
+ sg.node("primitive%d"%j, label=str(primitive))
149
+
150
+ for o in sorted(ordering.keys()):
151
+ with g.subgraph(name='cluster_%d'%o) as sg:
152
+ sg.graph_attr['rank'] = 'same'
153
+ #sg.attr(label='Depth %d'%o)
154
+ for p in ordering[o]:
155
+ if str(p) in englishDescriptions:
156
+ thisLabel = '<<font face="boldfontname"><u>%s</u></font><br />%s>'%(englishDescriptions[str(p)],simplification[p])
157
+ else:
158
+ eprint("WARNING: Do not have an English description of:\n",p)
159
+ eprint()
160
+ thisLabel = "<%s>"%simplification[p]
161
+ sg.node(getName(p),
162
+ label=thisLabel)
163
+
164
+ for p in ordering[o]:
165
+ children = {k
166
+ for _,k in p.body.walk()
167
+ if k.isInvented}
168
+ for k in children:
169
+ g.edge(name[k],name[p])
170
+
171
+ eprint("Exporting primitive graph to",fn)
172
+ try:
173
+ g.render(fn,view=view)
174
+ except Exception as e:
175
+ eprint("Got some kind of error while trying to render primitive graph! Did you install graphviz/dot?")
176
+ print(e)
177
+
178
+
179
+
180
+ makeGraph(depth2primitives,prefix+'depth.pdf')
181
+ makeUnorderedGraph(prefix+'unordered.pdf')
182
+ #makeGraph(age2primitives,prefix+'iter.pdf')
dreamcoder/program.py ADDED
@@ -0,0 +1,1214 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from dreamcoder.type import *
4
+ from dreamcoder.utilities import *
5
+
6
+ from time import time
7
+ import math
8
+
9
+
10
+ class InferenceFailure(Exception):
11
+ pass
12
+
13
+
14
+ class ShiftFailure(Exception):
15
+ pass
16
+
17
+ class RunFailure(Exception):
18
+ pass
19
+
20
+
21
+ class Program(object):
22
+ def __repr__(self): return str(self)
23
+
24
+ def __ne__(self, o): return not (self == o)
25
+
26
+ def __str__(self): return self.show(False)
27
+
28
+ def canHaveType(self, t):
29
+ try:
30
+ context, actualType = self.inferType(Context.EMPTY, [], {})
31
+ context, t = t.instantiate(context)
32
+ context.unify(t, actualType)
33
+ return True
34
+ except UnificationFailure as e:
35
+ return False
36
+
37
+ def betaNormalForm(self):
38
+ n = self
39
+ while True:
40
+ np = n.betaReduce()
41
+ if np is None: return n
42
+ n = np
43
+
44
+ def infer(self):
45
+ try:
46
+ return self.inferType(Context.EMPTY, [], {})[1].canonical()
47
+ except UnificationFailure as e:
48
+ raise InferenceFailure(self, e)
49
+
50
+ def uncurry(self):
51
+ t = self.infer()
52
+ a = len(t.functionArguments())
53
+ e = self
54
+ existingAbstractions = 0
55
+ while e.isAbstraction:
56
+ e = e.body
57
+ existingAbstractions += 1
58
+ newAbstractions = a - existingAbstractions
59
+ assert newAbstractions >= 0
60
+
61
+ # e is the body stripped of abstractions. we are going to pile
62
+ # some more lambdas at the front, so free variables in e
63
+ # (which were bound to the stripped abstractions) need to be
64
+ # shifted by the number of abstractions that we will be adding
65
+ e = e.shift(newAbstractions)
66
+
67
+ for n in reversed(range(newAbstractions)):
68
+ e = Application(e, Index(n))
69
+ for _ in range(a):
70
+ e = Abstraction(e)
71
+
72
+ assert self.infer() == e.infer(), \
73
+ "FATAL: uncurry has a bug. %s : %s, but uncurried to %s : %s" % (self, self.infer(),
74
+ e, e.infer())
75
+ return e
76
+
77
+ def wellTyped(self):
78
+ try:
79
+ self.infer()
80
+ return True
81
+ except InferenceFailure:
82
+ return False
83
+
84
+ def runWithArguments(self, xs):
85
+ f = self.evaluate([])
86
+ for x in xs:
87
+ f = f(x)
88
+ return f
89
+
90
+ def applicationParses(self): yield self, []
91
+
92
+ def applicationParse(self): return self, []
93
+
94
+ @property
95
+ def closed(self):
96
+ for surroundingAbstractions, child in self.walk():
97
+ if isinstance(child, FragmentVariable):
98
+ return False
99
+ if isinstance(child, Index) and child.free(
100
+ surroundingAbstractions):
101
+ return False
102
+ return True
103
+
104
+ @property
105
+ def numberOfFreeVariables(expression):
106
+ n = 0
107
+ for surroundingAbstractions, child in expression.walk():
108
+ # Free variable
109
+ if isinstance(child, Index) and child.free(
110
+ surroundingAbstractions):
111
+ n = max(n, child.i - surroundingAbstractions + 1)
112
+ return n
113
+
114
+ def freeVariables(self):
115
+ for surroundingAbstractions, child in self.walk():
116
+ if child.isIndex and child.i >= surroundingAbstractions:
117
+ yield child.i - surroundingAbstractions
118
+
119
+ @property
120
+ def isIndex(self): return False
121
+
122
+ @property
123
+ def isUnion(self): return False
124
+
125
+ @property
126
+ def isApplication(self): return False
127
+
128
+ @property
129
+ def isAbstraction(self): return False
130
+
131
+ @property
132
+ def isPrimitive(self): return False
133
+
134
+ @property
135
+ def isInvented(self): return False
136
+
137
+ @property
138
+ def isHole(self): return False
139
+
140
+ @staticmethod
141
+ def parse(s):
142
+ s = parseSExpression(s)
143
+ def p(e):
144
+ if isinstance(e,list):
145
+ if e[0] == '#':
146
+ assert len(e) == 2
147
+ return Invented(p(e[1]))
148
+ if e[0] == 'lambda':
149
+ assert len(e) == 2
150
+ return Abstraction(p(e[1]))
151
+ f = p(e[0])
152
+ for x in e[1:]:
153
+ f = Application(f,p(x))
154
+ return f
155
+ assert isinstance(e,str)
156
+ if e[0] == '$': return Index(int(e[1:]))
157
+ if e in Primitive.GLOBALS: return Primitive.GLOBALS[e]
158
+ if e == '??' or e == '?': return FragmentVariable.single
159
+ if e == '<HOLE>': return Hole.single
160
+ raise ParseFailure((s,e))
161
+ return p(s)
162
+
163
+ @staticmethod
164
+ def _parse(s,n):
165
+ while n < len(s) and s[n].isspace():
166
+ n += 1
167
+ for p in [
168
+ Application,
169
+ Abstraction,
170
+ Index,
171
+ Invented,
172
+ FragmentVariable,
173
+ Hole,
174
+ Primitive]:
175
+ try:
176
+ return p._parse(s,n)
177
+ except ParseFailure:
178
+ continue
179
+ raise ParseFailure(s)
180
+
181
+ # parser helpers
182
+ @staticmethod
183
+ def parseConstant(s,n,*constants):
184
+ for constant in constants:
185
+ try:
186
+ for i,c in enumerate(constant):
187
+ if i + n >= len(s) or s[i + n] != c: raise ParseFailure(s)
188
+ return n + len(constant)
189
+ except ParseFailure: continue
190
+ raise ParseFailure(s)
191
+
192
+ @staticmethod
193
+ def parseHumanReadable(s):
194
+ s = parseSExpression(s)
195
+ def p(s, environment):
196
+ if isinstance(s, list) and s[0] in ['lambda','\\']:
197
+ assert isinstance(s[1], list) and len(s) == 3
198
+ newEnvironment = list(reversed(s[1])) + environment
199
+ e = p(s[2], newEnvironment)
200
+ for _ in s[1]: e = Abstraction(e)
201
+ return e
202
+ if isinstance(s, list):
203
+ a = p(s[0], environment)
204
+ for x in s[1:]:
205
+ a = Application(a, p(x, environment))
206
+ return a
207
+ for j,v in enumerate(environment):
208
+ if s == v: return Index(j)
209
+ if s in Primitive.GLOBALS: return Primitive.GLOBALS[s]
210
+ assert False, f"could not parse {s}"
211
+ return p(s, [])
212
+
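A round-trip sketch of the Program API, assuming the arithmetic primitives (e.g. "+", "1") have been registered by importing dreamcoder.domains.arithmetic.arithmeticPrimitives, as the grammar.py __main__ block above does:

    e = Program.parse("(lambda (+ $0 1))")
    print(e.infer())                 # int -> int
    print(e.runWithArguments([41]))  # 42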
213
+
214
+
215
+
216
+ class Application(Program):
217
+ '''Function application'''
218
+
219
+ def __init__(self, f, x):
220
+ self.f = f
221
+ self.x = x
222
+ self.hashCode = None
223
+ self.isConditional = (not isinstance(f,int)) and \
224
+ f.isApplication and \
225
+ f.f.isApplication and \
226
+ f.f.f.isPrimitive and \
227
+ f.f.f.name == "if"
228
+ if self.isConditional:
229
+ self.falseBranch = x
230
+ self.trueBranch = f.x
231
+ self.branch = f.f.x
232
+ else:
233
+ self.falseBranch = None
234
+ self.trueBranch = None
235
+ self.branch = None
236
+
237
+ def betaReduce(self):
238
+ # See if either the function or the argument can be reduced
239
+ f = self.f.betaReduce()
240
+ if f is not None: return Application(f,self.x)
241
+ x = self.x.betaReduce()
242
+ if x is not None: return Application(self.f,x)
243
+
244
+ # Neither of them could be reduced. Is this not a redex?
245
+ if not self.f.isAbstraction: return None
246
+
247
+ # Perform substitution
248
+ b = self.f.body
249
+ v = self.x
250
+ return b.substitute(Index(0), v.shift(1)).shift(-1)
251
+
252
+ def isBetaLong(self):
253
+ return (not self.f.isAbstraction) and self.f.isBetaLong() and self.x.isBetaLong()
254
+
255
+ def freeVariables(self):
256
+ return self.f.freeVariables() | self.x.freeVariables()
257
+
258
+ def clone(self): return Application(self.f.clone(), self.x.clone())
259
+
260
+ def annotateTypes(self, context, environment):
261
+ self.f.annotateTypes(context, environment)
262
+ self.x.annotateTypes(context, environment)
263
+ r = context.makeVariable()
264
+ context.unify(arrow(self.x.annotatedType, r), self.f.annotatedType)
265
+ self.annotatedType = r.applyMutable(context)
266
+
267
+
268
+ @property
269
+ def isApplication(self): return True
270
+
271
+ def __eq__(self, other): return isinstance(other, Application) and self.f == other.f and self.x == other.x
276
+
277
+ def __hash__(self):
278
+ if self.hashCode is None:
279
+ self.hashCode = hash((hash(self.f), hash(self.x)))
280
+ return self.hashCode
281
+
282
+ """Because Python3 randomizes the hash function, we need to never pickle the hash"""
283
+ def __getstate__(self):
284
+ return self.f, self.x, self.isConditional, self.falseBranch, self.trueBranch, self.branch
285
+ def __setstate__(self, state):
286
+ try:
287
+ self.f, self.x, self.isConditional, self.falseBranch, self.trueBranch, self.branch = state
288
+ except ValueError:
289
+ # backward compatibility
290
+ assert 'x' in state
291
+ assert 'f' in state
292
+ f = state['f']
293
+ x = state['x']
294
+ self.f = f
295
+ self.x = x
296
+ self.isConditional = (not isinstance(f,int)) and \
297
+ f.isApplication and \
298
+ f.f.isApplication and \
299
+ f.f.f.isPrimitive and \
300
+ f.f.f.name == "if"
301
+ if self.isConditional:
302
+ self.falseBranch = x
303
+ self.trueBranch = f.x
304
+ self.branch = f.f.x
305
+ else:
306
+ self.falseBranch = None
307
+ self.trueBranch = None
308
+ self.branch = None
309
+
310
+ self.hashCode = None
311
+
312
+ def visit(self, visitor, *arguments, **keywords): return visitor.application(self, *arguments, **keywords)
318
+
319
+ def show(self, isFunction):
320
+ if isFunction:
321
+ return "%s %s" % (self.f.show(True), self.x.show(False))
322
+ else:
323
+ return "(%s %s)" % (self.f.show(True), self.x.show(False))
324
+
325
+ def evaluate(self, environment):
326
+ if self.isConditional:
327
+ if self.branch.evaluate(environment):
328
+ return self.trueBranch.evaluate(environment)
329
+ else:
330
+ return self.falseBranch.evaluate(environment)
331
+ else:
332
+ return self.f.evaluate(environment)(self.x.evaluate(environment))
333
+
334
+ def inferType(self, context, environment, freeVariables):
335
+ (context, ft) = self.f.inferType(context, environment, freeVariables)
336
+ (context, xt) = self.x.inferType(context, environment, freeVariables)
337
+ (context, returnType) = context.makeVariable()
338
+ context = context.unify(ft, arrow(xt, returnType))
339
+ return (context, returnType.apply(context))
340
+
341
+ def applicationParses(self):
342
+ yield self, []
343
+ for f, xs in self.f.applicationParses():
344
+ yield f, xs + [self.x]
345
+
346
+ def applicationParse(self):
347
+ f, xs = self.f.applicationParse()
348
+ return f, xs + [self.x]
349
+
350
+ def shift(self, offset, depth=0):
351
+ return Application(self.f.shift(offset, depth),
352
+ self.x.shift(offset, depth))
353
+
354
+ def substitute(self, old, new):
355
+ if self == old:
356
+ return new
357
+ return Application(
358
+ self.f.substitute(
359
+ old, new), self.x.substitute(
360
+ old, new))
361
+
362
+ def walkUncurried(self, d=0):
363
+ yield d, self
364
+ f, xs = self.applicationParse()
365
+ yield from f.walkUncurried(d)
366
+ for x in xs:
367
+ yield from x.walkUncurried(d)
368
+
369
+ def walk(self, surroundingAbstractions=0):
370
+ yield surroundingAbstractions, self
371
+ yield from self.f.walk(surroundingAbstractions)
372
+ yield from self.x.walk(surroundingAbstractions)
373
+
374
+ def size(self): return self.f.size() + self.x.size()
375
+
376
+ @staticmethod
377
+ def _parse(s,n):
378
+ while n < len(s) and s[n].isspace(): n += 1
379
+ if n == len(s) or s[n] != '(': raise ParseFailure(s)
380
+ n += 1
381
+
382
+ xs = []
383
+ while True:
384
+ x, n = Program._parse(s, n)
385
+ xs.append(x)
386
+ while n < len(s) and s[n].isspace(): n += 1
387
+ if n == len(s):
388
+ raise ParseFailure(s)
389
+ if s[n] == ")":
390
+ n += 1
391
+ break
392
+ e = xs[0]
393
+ for x in xs[1:]:
394
+ e = Application(e, x)
395
+ return e, n
396
+
397
+
398
+ class Index(Program):
399
+ '''
400
+ deBruijn index: https://en.wikipedia.org/wiki/De_Bruijn_index
401
+ These indices encode variables.
402
+ '''
403
+
404
+ def __init__(self, i):
405
+ self.i = i
406
+
407
+ def show(self, isFunction): return "$%d" % self.i
408
+
409
+ def __eq__(self, o): return isinstance(o, Index) and o.i == self.i
410
+
411
+ def __hash__(self): return self.i
412
+
413
+ def visit(self, visitor, *arguments, **keywords): return visitor.index(self, *arguments, **keywords)
419
+
420
+ def evaluate(self, environment):
421
+ return environment[self.i]
422
+
423
+ def inferType(self, context, environment, freeVariables):
424
+ if self.bound(len(environment)):
425
+ return (context, environment[self.i].apply(context))
426
+ else:
427
+ i = self.i - len(environment)
428
+ if i in freeVariables:
429
+ return (context, freeVariables[i].apply(context))
430
+ context, variable = context.makeVariable()
431
+ freeVariables[i] = variable
432
+ return (context, variable)
433
+
434
+ def clone(self): return Index(self.i)
435
+
436
+ def annotateTypes(self, context, environment):
437
+ self.annotatedType = environment[self.i].applyMutable(context)
438
+
439
+ def shift(self, offset, depth=0):
440
+ # bound variable
441
+ if self.bound(depth):
442
+ return self
443
+ else: # free variable
444
+ i = self.i + offset
445
+ if i < 0:
446
+ raise ShiftFailure()
447
+ return Index(i)
448
+
449
+ def betaReduce(self): return None
450
+
451
+ def isBetaLong(self): return True
452
+
453
+ def freeVariables(self): return {self.i}
454
+
455
+ def substitute(self, old, new):
456
+ if old == self:
457
+ return new
458
+ else:
459
+ return self
460
+
461
+ def walk(self, surroundingAbstractions=0): yield surroundingAbstractions, self
462
+
463
+ def walkUncurried(self, d=0): yield d, self
464
+
465
+ def size(self): return 1
466
+
467
+ def free(self, surroundingAbstractions):
468
+ '''Is this index a free variable, given that it has surroundingAbstractions lambdas around it?'''
469
+ return self.i >= surroundingAbstractions
470
+
471
+ def bound(self, surroundingAbstractions):
472
+ '''Is this index a bound variable, given that it has surroundingAbstractions lambda's around it?'''
473
+ return self.i < surroundingAbstractions
474
+
475
+ @property
476
+ def isIndex(self): return True
477
+
478
+ @staticmethod
479
+ def _parse(s,n):
480
+ while n < len(s) and s[n].isspace(): n += 1
481
+ if n == len(s) or s[n] != '$':
482
+ raise ParseFailure(s)
483
+ n += 1
484
+ j = ""
485
+ while n < len(s) and s[n].isdigit():
486
+ j += s[n]
487
+ n += 1
488
+ if j == "":
489
+ raise ParseFailure(s)
490
+ return Index(int(j)), n
491
+
492
+
493
+ class Abstraction(Program):
494
+ '''Lambda abstraction. Creates a new function.'''
495
+
496
+ def __init__(self, body):
497
+ self.body = body
498
+ self.hashCode = None
499
+
500
+ @property
501
+ def isAbstraction(self): return True
502
+
503
+ def __eq__(self, o): return isinstance(
504
+ o, Abstraction) and o.body == self.body
505
+
506
+ def __hash__(self):
507
+ if self.hashCode is None:
508
+ self.hashCode = hash((hash(self.body),))
509
+ return self.hashCode
510
+
511
+ """Because Python3 randomizes the hash function, we need to never pickle the hash"""
512
+ def __getstate__(self):
513
+ return self.body
514
+ def __setstate__(self, state):
515
+ self.body = state
516
+ self.hashCode = None
517
+
518
+ def isBetaLong(self): return self.body.isBetaLong()
519
+
520
+ def freeVariables(self):
521
+ return {f - 1 for f in self.body.freeVariables() if f > 0}
522
+
523
+ def visit(self,
524
+ visitor,
525
+ *arguments,
526
+ **keywords): return visitor.abstraction(self,
527
+ *arguments,
528
+ **keywords)
529
+
530
+ def clone(self): return Abstraction(self.body.clone())
531
+
532
+ def annotateTypes(self, context, environment):
533
+ v = context.makeVariable()
534
+ self.body.annotateTypes(context, [v] + environment)
535
+ self.annotatedType = arrow(v.applyMutable(context), self.body.annotatedType)
536
+
537
+ def show(self, isFunction):
538
+ return "(lambda %s)" % (self.body.show(False))
539
+
540
+ def evaluate(self, environment):
541
+ return lambda x: self.body.evaluate([x] + environment)
542
+
543
+ def betaReduce(self):
544
+ b = self.body.betaReduce()
545
+ if b is None: return None
546
+ return Abstraction(b)
547
+
548
+ def inferType(self, context, environment, freeVariables):
549
+ (context, argumentType) = context.makeVariable()
550
+ (context, returnType) = self.body.inferType(
551
+ context, [argumentType] + environment, freeVariables)
552
+ return (context, arrow(argumentType, returnType).apply(context))
553
+
554
+ def shift(self, offset, depth=0):
555
+ return Abstraction(self.body.shift(offset, depth + 1))
556
+
557
+ def substitute(self, old, new):
558
+ if self == old:
559
+ return new
560
+ old = old.shift(1)
561
+ new = new.shift(1)
562
+ return Abstraction(self.body.substitute(old, new))
563
+
564
+ def walk(self, surroundingAbstractions=0):
565
+ yield surroundingAbstractions, self
566
+ yield from self.body.walk(surroundingAbstractions + 1)
567
+
568
+ def walkUncurried(self, d=0):
569
+ yield d, self
570
+ yield from self.body.walkUncurried(d + 1)
571
+
572
+ def size(self): return self.body.size()
573
+
574
+ @staticmethod
575
+ def _parse(s,n):
576
+ n = Program.parseConstant(s,n,
577
+ '(\\','(lambda','(\u03bb')
578
+
579
+ while n < len(s) and s[n].isspace(): n += 1
580
+
581
+ b, n = Program._parse(s,n)
582
+ while n < len(s) and s[n].isspace(): n += 1
583
+ n = Program.parseConstant(s,n,')')
584
+ return Abstraction(b), n
585
+
586
+
587
+ class Primitive(Program):
588
+ GLOBALS = {}
589
+
590
+ def __init__(self, name, ty, value):
591
+ self.tp = ty
592
+ self.name = name
593
+ self.value = value
594
+ if name not in Primitive.GLOBALS:
595
+ Primitive.GLOBALS[name] = self
596
+
597
+ @property
598
+ def isPrimitive(self): return True
599
+
600
+ def __eq__(self, o): return isinstance(
601
+ o, Primitive) and o.name == self.name
602
+
603
+ def __hash__(self): return hash(self.name)
604
+
605
+ def visit(self,
606
+ visitor,
607
+ *arguments,
608
+ **keywords): return visitor.primitive(self,
609
+ *arguments,
610
+ **keywords)
611
+
612
+ def show(self, isFunction): return self.name
613
+
614
+ def clone(self): return Primitive(self.name, self.tp, self.value)
615
+
616
+ def annotateTypes(self, context, environment):
617
+ self.annotatedType = self.tp.instantiateMutable(context)
618
+
619
+ def evaluate(self, environment): return self.value
620
+
621
+ def betaReduce(self): return None
622
+
623
+ def isBetaLong(self): return True
624
+
625
+ def freeVariables(self): return set()
626
+
627
+ def inferType(self, context, environment, freeVariables):
628
+ return self.tp.instantiate(context)
629
+
630
+ def shift(self, offset, depth=0): return self
631
+
632
+ def substitute(self, old, new):
633
+ if self == old:
634
+ return new
635
+ else:
636
+ return self
637
+
638
+ def walk(self, surroundingAbstractions=0): yield surroundingAbstractions, self
639
+
640
+ def walkUncurried(self, d=0): yield d, self
641
+
642
+ def size(self): return 1
643
+
644
+ @staticmethod
645
+ def _parse(s,n):
646
+ while n < len(s) and s[n].isspace(): n += 1
647
+ name = []
648
+ while n < len(s) and not s[n].isspace() and s[n] not in '()':
649
+ name.append(s[n])
650
+ n += 1
651
+ name = "".join(name)
652
+ if name in Primitive.GLOBALS:
653
+ return Primitive.GLOBALS[name], n
654
+ raise ParseFailure(s)
655
+
656
+ # TODO(@mtensor): needs to be fixed to handle both pickling lambda functions and unpickling in general.
657
+ # def __getstate__(self):
658
+ # return self.name
659
+
660
+ # def __setstate__(self, state):
661
+ # #for backwards compatibility:
662
+ # if type(state) == dict:
663
+ # self.__dict__ = state
664
+ # else:
665
+ # p = Primitive.GLOBALS[state]
666
+ # self.__init__(p.name, p.tp, p.value)
667
+
668
+ class Invented(Program):
669
+ '''New invented primitives'''
670
+
671
+ def __init__(self, body):
672
+ self.body = body
673
+ self.tp = self.body.infer()
674
+ self.hashCode = None
675
+
676
+ @property
677
+ def isInvented(self): return True
678
+
679
+ def show(self, isFunction): return "#%s" % (self.body.show(False))
680
+
681
+ def visit(self,
682
+ visitor,
683
+ *arguments,
684
+ **keywords): return visitor.invented(self,
685
+ *arguments,
686
+ **keywords)
687
+
688
+ def __eq__(self, o): return isinstance(o, Invented) and o.body == self.body
689
+
690
+ def __hash__(self):
691
+ if self.hashCode is None:
692
+ self.hashCode = hash((0, hash(self.body)))
693
+ return self.hashCode
694
+
695
+ """Because Python3 randomizes the hash function, we need to never pickle the hash"""
696
+ def __getstate__(self):
697
+ return self.body, self.tp
698
+ def __setstate__(self, state):
699
+ self.body, self.tp = state
700
+ self.hashCode = None
701
+
702
+ def clone(self): return Invented(self.body)
703
+
704
+ def annotateTypes(self, context, environment):
705
+ self.annotatedType = self.tp.instantiateMutable(context)
706
+
707
+ def evaluate(self, e): return self.body.evaluate([])
708
+
709
+ def betaReduce(self): return self.body
710
+
711
+ def isBetaLong(self): return True
712
+
713
+ def freeVariables(self): return set()
714
+
715
+ def inferType(self, context, environment, freeVariables):
716
+ return self.tp.instantiate(context)
717
+
718
+ def shift(self, offset, depth=0): return self
719
+
720
+ def substitute(self, old, new):
721
+ if self == old:
722
+ return new
723
+ else:
724
+ return self
725
+
726
+ def walk(self, surroundingAbstractions=0): yield surroundingAbstractions, self
727
+
728
+ def walkUncurried(self, d=0): yield d, self
729
+
730
+ def size(self): return 1
731
+
732
+ @staticmethod
733
+ def _parse(s,n):
734
+ while n < len(s) and s[n].isspace(): n += 1
735
+ if n < len(s) and s[n] == '#':
736
+ n += 1
737
+ b,n = Program._parse(s,n)
738
+ return Invented(b),n
739
+
740
+ raise ParseFailure(s)
741
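A small sketch of the '#' notation parsed above (same assumption: arithmetic primitives registered): parsing an invented primitive eagerly type-checks its closed body, and evaluation simply evaluates that body in the empty environment.

inc = Program.parse("#(lambda (+ 1 $0))")
assert inc.isInvented
assert inc.evaluate([])(1) == 2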
+
742
+
743
+ class FragmentVariable(Program):
744
+ def __init__(self): pass
745
+
746
+ def show(self, isFunction): return "??"
747
+
748
+ def __eq__(self, o): return isinstance(o, FragmentVariable)
749
+
750
+ def __hash__(self): return 42
751
+
752
+ def visit(self, visitor, *arguments, **keywords):
753
+ return visitor.fragmentVariable(self, *arguments, **keywords)
754
+
755
+ def evaluate(self, e):
756
+ raise Exception('Attempt to evaluate fragment variable')
757
+
758
+ def betaReduce(self):
759
+ raise Exception('Attempt to beta reduce fragment variable')
760
+
761
+ def inferType(self, context, environment, freeVariables):
762
+ return context.makeVariable()
763
+
764
+ def shift(self, offset, depth=0):
765
+ raise Exception('Attempt to shift fragment variable')
766
+
767
+ def substitute(self, old, new):
768
+ if self == old:
769
+ return new
770
+ else:
771
+ return self
772
+
773
+ def match(
774
+ self,
775
+ context,
776
+ expression,
777
+ holes,
778
+ variableBindings,
779
+ environment=[]):
780
+ surroundingAbstractions = len(environment)
781
+ try:
782
+ context, variable = context.makeVariable()
783
+ holes.append(
784
+ (variable, expression.shift(-surroundingAbstractions)))
785
+ return context, variable
786
+ except ShiftFailure:
787
+ raise MatchFailure()
788
+
789
+ def walk(self, surroundingAbstractions=0): yield surroundingAbstractions, self
790
+
791
+ def walkUncurried(self, d=0): yield d, self
792
+
793
+ def size(self): return 1
794
+
795
+ @staticmethod
796
+ def _parse(s,n):
797
+ while n < len(s) and s[n].isspace(): n += 1
798
+ n = Program.parseConstant(s,n,'??','?')
799
+ return FragmentVariable.single, n
800
+
801
+ FragmentVariable.single = FragmentVariable()
802
+
803
+
804
+ class Hole(Program):
805
+ def __init__(self): pass
806
+
807
+ def show(self, isFunction): return "<HOLE>"
808
+
809
+ @property
810
+ def isHole(self): return True
811
+
812
+ def __eq__(self, o): return isinstance(o, Hole)
813
+
814
+ def __hash__(self): return 42
815
+
816
+ def evaluate(self, e):
817
+ raise Exception('Attempt to evaluate hole')
818
+
819
+ def betaReduce(self):
820
+ raise Exception('Attempt to beta reduce hole')
821
+
822
+ def inferType(self, context, environment, freeVariables):
823
+ return context.makeVariable()
824
+
825
+ def shift(self, offset, depth=0):
826
+ raise Exception('Attempt to shift hole')
827
+
828
+ def walk(self, surroundingAbstractions=0): yield surroundingAbstractions, self
829
+
830
+ def walkUncurried(self, d=0): yield d, self
831
+
832
+ def size(self): return 1
833
+
834
+ @staticmethod
835
+ def _parse(s,n):
836
+ while n < len(s) and s[n].isspace(): n += 1
837
+ n = Program.parseConstant(s,n,
838
+ '<HOLE>')
839
+ return Hole.single, n
840
+
841
+
842
+ Hole.single = Hole()
843
+
844
+
845
+ class ShareVisitor(object):
846
+ def __init__(self):
847
+ self.primitiveTable = {}
848
+ self.inventedTable = {}
849
+ self.indexTable = {}
850
+ self.applicationTable = {}
851
+ self.abstractionTable = {}
852
+
853
+ def invented(self, e):
854
+ body = e.body.visit(self)
855
+ i = id(body)
856
+ if i in self.inventedTable:
857
+ return self.inventedTable[i]
858
+ new = Invented(body)
859
+ self.inventedTable[i] = new
860
+ return new
861
+
862
+ def primitive(self, e):
863
+ if e.name in self.primitiveTable:
864
+ return self.primitiveTable[e.name]
865
+ self.primitiveTable[e.name] = e
866
+ return e
867
+
868
+ def index(self, e):
869
+ if e.i in self.indexTable:
870
+ return self.indexTable[e.i]
871
+ self.indexTable[e.i] = e
872
+ return e
873
+
874
+ def application(self, e):
875
+ f = e.f.visit(self)
876
+ x = e.x.visit(self)
877
+ fi = id(f)
878
+ xi = id(x)
879
+ i = (fi, xi)
880
+ if i in self.applicationTable:
881
+ return self.applicationTable[i]
882
+ new = Application(f, x)
883
+ self.applicationTable[i] = new
884
+ return new
885
+
886
+ def abstraction(self, e):
887
+ body = e.body.visit(self)
888
+ i = id(body)
889
+ if i in self.abstractionTable:
890
+ return self.abstractionTable[i]
891
+ new = Abstraction(body)
892
+ self.abstractionTable[i] = new
893
+ return new
894
+
895
+ def execute(self, e):
896
+ return e.visit(self)
897
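A minimal usage sketch of this hash-consing pass: after visiting, structurally identical subterms are represented by one shared Python object (indices are keyed by their value, primitives by name, applications and abstractions by the identities of their children).

e = Program.parse("(+ (+ 1 1) (+ 1 1))")
shared = ShareVisitor().execute(e)
assert shared == e
assert shared.f.x is shared.x   # both (+ 1 1) subterms are now the same object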
+
898
+
899
+ class Mutator:
900
+ """Perform local mutations to an expr, yielding the expr and the
901
+ description length distance from the original program"""
902
+
903
+ def __init__(self, grammar, fn):
904
+ """Fn yields (expression, loglikelihood) from a type and loss.
905
+ Therefore, loss+loglikelihood is the distance from the original program."""
906
+ self.fn = fn
907
+ self.grammar = grammar
908
+ self.history = []
909
+
910
+ def enclose(self, expr):
911
+ for h in self.history[::-1]:
912
+ expr = h(expr)
913
+ return expr
914
+
915
+ def invented(self, e, tp, env, is_lhs=False):
916
+ deleted_ll = self.logLikelihood(tp, e, env)
917
+ for expr, replaced_ll in self.fn(tp, deleted_ll, is_left_application=is_lhs):
918
+ yield self.enclose(expr), deleted_ll + replaced_ll
919
+
920
+ def primitive(self, e, tp, env, is_lhs=False):
921
+ deleted_ll = self.logLikelihood(tp, e, env)
922
+ for expr, replaced_ll in self.fn(tp, deleted_ll, is_left_application=is_lhs):
923
+ yield self.enclose(expr), deleted_ll + replaced_ll
924
+
925
+ def index(self, e, tp, env, is_lhs=False):
926
+ #yield from ()
927
+ deleted_ll = self.logLikelihood(tp, e, env) #self.grammar.logVariable
928
+ for expr, replaced_ll in self.fn(tp, deleted_ll, is_left_application=is_lhs):
929
+ yield self.enclose(expr), deleted_ll + replaced_ll
930
+
931
+ def application(self, e, tp, env, is_lhs=False):
932
+ self.history.append(lambda expr: Application(expr, e.x))
933
+ f_tp = arrow(e.x.infer(), tp)
934
+ yield from e.f.visit(self, f_tp, env, is_lhs=True)
935
+ self.history[-1] = lambda expr: Application(e.f, expr)
936
+ x_tp = inferArg(tp, e.f.infer())
937
+ yield from e.x.visit(self, x_tp, env)
938
+ self.history.pop()
939
+ deleted_ll = self.logLikelihood(tp, e, env)
940
+ for expr, replaced_ll in self.fn(tp, deleted_ll, is_left_application=is_lhs):
941
+ yield self.enclose(expr), deleted_ll + replaced_ll
942
+
943
+ def abstraction(self, e, tp, env, is_lhs=False):
944
+ self.history.append(lambda expr: Abstraction(expr))
945
+ yield from e.body.visit(self, tp.arguments[1], [tp.arguments[0]]+env)
946
+ self.history.pop()
947
+ deleted_ll = self.logLikelihood(tp, e, env)
948
+ for expr, replaced_ll in self.fn(tp, deleted_ll, is_left_application=is_lhs):
949
+ yield self.enclose(expr), deleted_ll + replaced_ll
950
+
951
+ def execute(self, e, tp):
952
+ yield from e.visit(self, tp, [])
953
+
954
+ def logLikelihood(self, tp, e, env):
955
+ summary = None
956
+ try:
957
+ _, summary = self.grammar.likelihoodSummary(Context.EMPTY, env,
958
+ tp, e, silent=True)
959
+ except AssertionError as err:
960
+ #print(f"closedLikelihoodSummary failed on tp={tp}, e={e}, error={err}")
961
+ pass
962
+ if summary is not None:
963
+ return summary.logLikelihood(self.grammar)
964
+ else:
965
+ tmpE, depth = e, 0
966
+ while isinstance(tmpE, Abstraction):
967
+ depth += 1
968
+ tmpE = tmpE.body
969
+ to_introduce = len(tp.functionArguments()) - depth
970
+ if to_introduce == 0:
971
+ #print(f"HIT NEGATIVEINFINITY, tp={tp}, e={e}")
972
+ return NEGATIVEINFINITY
973
+ for i in reversed(range(to_introduce)):
974
+ e = Application(e, Index(i))
975
+ for _ in range(to_introduce):
976
+ e = Abstraction(e)
977
+ return self.logLikelihood(tp, e, env)
978
+
979
+
980
+ class RegisterPrimitives(object):
981
+ def invented(self, e): e.body.visit(self)
982
+
983
+ def primitive(self, e):
984
+ if e.name not in Primitive.GLOBALS:
985
+ Primitive(e.name, e.tp, e.value)
986
+
987
+ def index(self, e): pass
988
+
989
+ def application(self, e):
990
+ e.f.visit(self)
991
+ e.x.visit(self)
992
+
993
+ def abstraction(self, e): e.body.visit(self)
994
+
995
+ @staticmethod
996
+ def register(e): e.visit(RegisterPrimitives())
997
+
998
+
999
+ class PrettyVisitor(object):
1000
+ def __init__(self, Lisp=False):
1001
+ self.Lisp = Lisp
1002
+ self.numberOfVariables = 0
1003
+ self.freeVariables = {}
1004
+
1005
+ self.variableNames = ["x", "y", "z", "u", "v", "w"]
1006
+ self.variableNames += [chr(ord('a') + j)
1007
+ for j in range(20)]
1008
+ self.toplevel = True
1009
+
1010
+ def makeVariable(self):
1011
+ v = self.variableNames[self.numberOfVariables]
1012
+ self.numberOfVariables += 1
1013
+ return v
1014
+
1015
+ def invented(self, e, environment, isFunction, isAbstraction):
1016
+ s = e.body.visit(self, [], isFunction, isAbstraction)
1017
+ return s
1018
+
1019
+ def primitive(self, e, environment, isVariable, isAbstraction): return e.name
1020
+
1021
+ def index(self, e, environment, isVariable, isAbstraction):
1022
+ if e.i < len(environment):
1023
+ return environment[e.i]
1024
+ else:
1025
+ i = e.i - len(environment)
1026
+ if i in self.freeVariables:
1027
+ return self.freeVariables[i]
1028
+ else:
1029
+ v = self.makeVariable()
1030
+ self.freeVariables[i] = v
1031
+ return v
1032
+
1033
+ def application(self, e, environment, isFunction, isAbstraction):
1034
+ self.toplevel = False
1035
+ s = "%s %s" % (e.f.visit(self, environment, True, False),
1036
+ e.x.visit(self, environment, False, False))
1037
+ if isFunction:
1038
+ return s
1039
+ else:
1040
+ return "(" + s + ")"
1041
+
1042
+ def abstraction(self, e, environment, isFunction, isAbstraction):
1043
+ toplevel = self.toplevel
1044
+ self.toplevel = False
1045
+ if not self.Lisp:
1046
+ # Invent a new variable
1047
+ v = self.makeVariable()
1048
+ body = e.body.visit(self,
1049
+ [v] + environment,
1050
+ False,
1051
+ True)
1052
+ if not e.body.isAbstraction:
1053
+ body = "." + body
1054
+ body = v + body
1055
+ if not isAbstraction:
1056
+ body = "λ" + body
1057
+ if not toplevel:
1058
+ body = "(%s)" % body
1059
+ return body
1060
+ else:
1061
+ child = e
1062
+ newVariables = []
1063
+ while child.isAbstraction:
1064
+ newVariables = [self.makeVariable()] + newVariables
1065
+ child = child.body
1066
+ body = child.visit(self, newVariables + environment,
1067
+ False, True)
1068
+ body = "(λ (%s) %s)"%(" ".join(reversed(newVariables)), body)
1069
+ return body
1070
+
1071
+
1072
+
1073
+ def prettyProgram(e, Lisp=False):
1074
+ return e.visit(PrettyVisitor(Lisp=Lisp), [], False, False)
1075
+
1076
+ class EtaExpandFailure(Exception): pass
1077
+ class EtaLongVisitor(object):
1078
+ """Converts an expression into eta-longform"""
1079
+ def __init__(self, request=None):
1080
+ self.request = request
1081
+ self.context = None
1082
+
1083
+ def makeLong(self, e, request):
1084
+ if request.isArrow():
1085
+ # eta expansion
1086
+ return Abstraction(Application(e.shift(1),
1087
+ Index(0)))
1088
+ return None
1089
+
1090
+
1091
+ def abstraction(self, e, request, environment):
1092
+ if not request.isArrow(): raise EtaExpandFailure()
1093
+
1094
+ return Abstraction(e.body.visit(self,
1095
+ request.arguments[1],
1096
+ [request.arguments[0]] + environment))
1097
+
1098
+ def _application(self, e, request, environment):
1099
+ l = self.makeLong(e, request)
1100
+ if l is not None: return l.visit(self, request, environment)
1101
+
1102
+ f, xs = e.applicationParse()
1103
+
1104
+ if f.isIndex:
1105
+ ft = environment[f.i].applyMutable(self.context)
1106
+ elif f.isInvented or f.isPrimitive:
1107
+ ft = f.tp.instantiateMutable(self.context)
1108
+ else: assert False, "Not in beta long form: %s"%e
1109
+
1110
+ self.context.unify(request, ft.returns())
1111
+ ft = ft.applyMutable(self.context)
1112
+
1113
+ xt = ft.functionArguments()
1114
+ if len(xs) != len(xt): raise EtaExpandFailure()
1115
+
1116
+ returnValue = f
1117
+ for x,t in zip(xs,xt):
1118
+ t = t.applyMutable(self.context)
1119
+ returnValue = Application(returnValue,
1120
+ x.visit(self, t, environment))
1121
+ return returnValue
1122
+
1123
+ # This procedure works by recapitulating the generative process
1124
+ # applications, indices, and primitives are all generated identically
1125
+
1126
+ def application(self, e, request, environment): return self._application(e, request, environment)
1127
+
1128
+ def index(self, e, request, environment): return self._application(e, request, environment)
1129
+
1130
+ def primitive(self, e, request, environment): return self._application(e, request, environment)
1131
+
1132
+ def invented(self, e, request, environment): return self._application(e, request, environment)
1133
+
1134
+ def execute(self, e):
1135
+ assert len(e.freeVariables()) == 0
1136
+
1137
+ if self.request is None:
1138
+ eprint("WARNING: request not specified for etaexpansion")
1139
+ self.request = e.infer()
1140
+ self.context = MutableContext()
1141
+ el = e.visit(self, self.request, [])
1142
+ self.context = None
1143
+ # assert el.infer().canonical() == e.infer().canonical(), \
1144
+ # f"Types are not preserved by ETA expansion: {e} : {e.infer().canonical()} vs {el} : {el.infer().canonical()}"
1145
+ return el
1146
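For example, eta-expanding a bare first-order primitive wraps it in lambdas until every function is fully applied (a sketch assuming the arithmetic primitives are registered and `arrow`/`tint` come from dreamcoder.type):

from dreamcoder.type import arrow, tint

p = Program.parse("+")                                  # : int -> int -> int
el = EtaLongVisitor(request=arrow(tint, tint, tint)).execute(p)
print(el)                                               # (lambda (lambda (+ $1 $0)))
assert el.isBetaLong()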
+
1147
+
1148
+
1149
+ class StripPrimitiveVisitor():
1150
+ """Replaces all primitives .value's w/ None. Does not destructively modify anything"""
1151
+ def invented(self,e):
1152
+ return Invented(e.body.visit(self))
1153
+ def primitive(self,e):
1154
+ return Primitive(e.name,e.tp,None)
1155
+ def application(self,e):
1156
+ return Application(e.f.visit(self),
1157
+ e.x.visit(self))
1158
+ def abstraction(self,e):
1159
+ return Abstraction(e.body.visit(self))
1160
+ def index(self,e): return e
1161
+
1162
+ class ReplacePrimitiveValueVisitor():
1163
+ """Intended to be used after StripPrimitiveVisitor.
1164
+ Replaces all primitive.value's with their corresponding entry in Primitive.GLOBALS"""
1165
+ def invented(self,e):
1166
+ return Invented(e.body.visit(self))
1167
+ def primitive(self,e):
1168
+ return Primitive(e.name,e.tp,Primitive.GLOBALS[e.name].value)
1169
+ def application(self,e):
1170
+ return Application(e.f.visit(self),
1171
+ e.x.visit(self))
1172
+ def abstraction(self,e):
1173
+ return Abstraction(e.body.visit(self))
1174
+ def index(self,e): return e
1175
+
1176
+ def strip_primitive_values(e):
1177
+ return e.visit(StripPrimitiveVisitor())
1178
+ def unstrip_primitive_values(e):
1179
+ return e.visit(ReplacePrimitiveValueVisitor())
1180
+
1181
+
1182
+ # from luke
1183
+ class TokeniseVisitor(object):
1184
+ def invented(self, e):
1185
+ return [e.body]
1186
+
1187
+ def primitive(self, e): return [e.name]
1188
+
1189
+ def index(self, e):
1190
+ return ["$" + str(e.i)]
1191
+
1192
+ def application(self, e):
1193
+ return ["("] + e.f.visit(self) + e.x.visit(self) + [")"]
1194
+
1195
+ def abstraction(self, e):
1196
+ return ["(_lambda"] + e.body.visit(self) + [")_lambda"]
1197
+
1198
+
1199
+ def tokeniseProgram(e):
1200
+ return e.visit(TokeniseVisitor())
1201
+
1202
+
1203
+ def untokeniseProgram(l):
1204
+ lookup = {
1205
+ "(_lambda": "(lambda",
1206
+ ")_lambda": ")"
1207
+ }
1208
+ s = " ".join(lookup.get(x, x) for x in l)
1209
+ return Program.parse(s)
1210
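A round-trip sketch (the token sequence in the comment is what the visitors above produce; it assumes every primitive in the program is registered):

p = Program.parse("(lambda (+ 1 $0))")
tokens = tokeniseProgram(p)
# ['(_lambda', '(', '(', '+', '1', ')', '$0', ')', ')_lambda']
assert untokeniseProgram(tokens) == p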
+
1211
+ if __name__ == "__main__":
1212
+ from dreamcoder.domains.arithmetic.arithmeticPrimitives import *
1213
+ e = Program.parse("(#(lambda (?? (+ 1 $0))) (lambda (?? (+ 1 $0))) (lambda (?? (+ 1 $0))) - * (+ +))")
1214
+ eprint(e)
dreamcoder/recognition.py ADDED
@@ -0,0 +1,1528 @@
1
+ from dreamcoder.enumeration import *
2
+ from dreamcoder.grammar import *
3
+ # luke
4
+
5
+
6
+ import gc
7
+
8
+ try:
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from torch.autograd import Variable
13
+ from torch.nn.utils.rnn import pack_padded_sequence
14
+ except Exception:
15
+ eprint("WARNING: Could not import torch. This is only okay when doing pypy compression.")
16
+
17
+ try:
18
+ import numpy as np
19
+ except Exception:
20
+ eprint("WARNING: Could not import np. This is only okay when doing pypy compression.")
21
+
22
+ import json
23
+
24
+
25
+ def variable(x, volatile=False, cuda=False):
26
+ if isinstance(x, list):
27
+ x = np.array(x)
28
+ if isinstance(x, (np.ndarray, np.generic)):
29
+ x = torch.from_numpy(x)
30
+ if cuda:
31
+ x = x.cuda()
32
+ return Variable(x, volatile=volatile)
33
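Note: torch.autograd.Variable and the volatile= flag used above are the legacy (pre-0.4) PyTorch API; on modern PyTorch the call still returns a tensor, but volatile is ignored with a warning, and new code would use torch.no_grad() instead.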
+
34
+ def maybe_cuda(x, use_cuda):
35
+ if use_cuda:
36
+ return x.cuda()
37
+ else:
38
+ return x
39
+
40
+
41
+ def is_torch_not_a_number(v):
42
+ """checks whether a tortured variable is nan"""
43
+ v = v.data
44
+ if not ((v == v).item()):
45
+ return True
46
+ return False
47
+
48
+ def is_torch_invalid(v):
49
+ """checks whether a torch variable is nan or inf"""
50
+ if is_torch_not_a_number(v):
51
+ return True
52
+ a = v - v
53
+ if is_torch_not_a_number(a):
54
+ return True
55
+ return False
56
+
57
+ def _relu(x): return x.clamp(min=0)
58
+
59
+ class Entropy(nn.Module):
60
+ """Calculates the entropy of logits"""
61
+ def __init__(self):
62
+ super(Entropy, self).__init__()
63
+
64
+ def forward(self, x):
65
+ b = F.softmax(x, dim=0) * F.log_softmax(x, dim=0)
66
+ b = -1.0 * b.sum()
67
+ return b
68
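A one-line sanity check of the definition above (H = -sum(softmax(x) * log_softmax(x))): uniform logits give entropy log(n). This assumes the torch import above succeeded.

import torch
assert abs(Entropy()(torch.zeros(4)).item() - torch.log(torch.tensor(4.0)).item()) < 1e-5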
+
69
+ class GrammarNetwork(nn.Module):
70
+ """Neural network that outputs a grammar"""
71
+ def __init__(self, inputDimensionality, grammar):
72
+ super(GrammarNetwork, self).__init__()
73
+ self.logProductions = nn.Linear(inputDimensionality, len(grammar)+1)
74
+ self.grammar = grammar
75
+
76
+ def forward(self, x):
77
+ """Takes as input inputDimensionality-dimensional vector and returns Grammar
78
+ Tensor-valued probabilities"""
79
+ logProductions = self.logProductions(x)
80
+ return Grammar(logProductions[-1].view(1), #logVariable
81
+ [(logProductions[k].view(1), t, program)
82
+ for k, (_, t, program) in enumerate(self.grammar.productions)],
83
+ continuationType=self.grammar.continuationType)
84
+
85
+ def batchedLogLikelihoods(self, xs, summaries):
86
+ """Takes as input BxinputDimensionality vector & B likelihood summaries;
87
+ returns B-dimensional vector containing log likelihood of each summary"""
88
+ use_cuda = xs.device.type == 'cuda'
89
+
90
+ B = xs.size(0)
91
+ assert len(summaries) == B
92
+ logProductions = self.logProductions(xs)
93
+
94
+ # uses[b][p] is # uses of primitive p by summary b
95
+ uses = np.zeros((B,len(self.grammar) + 1))
96
+ for b,summary in enumerate(summaries):
97
+ for p, production in enumerate(self.grammar.primitives):
98
+ uses[b,p] = summary.uses.get(production, 0.)
99
+ uses[b,len(self.grammar)] = summary.uses.get(Index(0), 0)
100
+
101
+ numerator = (logProductions * maybe_cuda(torch.from_numpy(uses).float(),use_cuda)).sum(1)
102
+ numerator += maybe_cuda(torch.tensor([summary.constant for summary in summaries ]).float(), use_cuda)
103
+
104
+ alternativeSet = {normalizer
105
+ for s in summaries
106
+ for normalizer in s.normalizers }
107
+ alternativeSet = list(alternativeSet)
108
+
109
+ mask = np.zeros((len(alternativeSet), len(self.grammar) + 1))
110
+ for tau in range(len(alternativeSet)):
111
+ for p, production in enumerate(self.grammar.primitives):
112
+ mask[tau,p] = 0. if production in alternativeSet[tau] else NEGATIVEINFINITY
113
+ mask[tau,len(self.grammar)] = 0. if Index(0) in alternativeSet[tau] else NEGATIVEINFINITY
114
+ mask = maybe_cuda(torch.tensor(mask).float(), use_cuda)
115
+
116
+ # mask: Rx|G|
117
+ # logProductions: Bx|G|
118
+ # Want: mask + logProductions : BxRx|G| = z
119
+ z = mask.repeat(B,1,1) + logProductions.repeat(len(alternativeSet),1,1).transpose(1,0)
120
+ # z: BxR
121
+ z = torch.logsumexp(z, 2) # pytorch 1.0 dependency
122
+
123
+ # Calculate how many times each normalizer was used
124
+ N = np.zeros((B, len(alternativeSet)))
125
+ for b, summary in enumerate(summaries):
126
+ for tau, alternatives in enumerate(alternativeSet):
127
+ N[b, tau] = summary.normalizers.get(alternatives,0.)
128
+
129
+ denominator = (maybe_cuda(torch.tensor(N).float(),use_cuda) * z).sum(1)
130
+ return numerator - denominator
131
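What the numerator and denominator above compute, as a standalone numpy sketch with made-up numbers (not dreamcoder API): a likelihood summary records how often each production was used (`uses`) and, for each choice point, the set of type-compatible alternatives (`normalizers`), so up to the summary's constant term, log p(program) = uses . logits - sum over choice points of logsumexp(logits restricted to the alternatives).

import numpy as np

def logsumexp(v):
    m = v.max()
    return m + np.log(np.exp(v - m).sum())

logits = np.array([-1.2, -0.7, -2.0])        # one score per production (last entry = variable)
uses = np.array([2.0, 0.0, 1.0])             # usage counts from the summary
normalizers = {(0, 1): 2.0, (0, 1, 2): 1.0}  # alternative set -> occurrence count
logL = uses @ logits - sum(count * logsumexp(logits[list(A)])
                           for A, count in normalizers.items())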
+
132
+
133
+
134
+ class ContextualGrammarNetwork_LowRank(nn.Module):
135
+ def __init__(self, inputDimensionality, grammar, R=16):
136
+ """Low-rank approximation to bigram model. Parameters is linear in number of primitives.
137
+ R: maximum rank"""
138
+
139
+ super(ContextualGrammarNetwork_LowRank, self).__init__()
140
+
141
+ self.grammar = grammar
142
+
143
+ self.R = R # embedding size
144
+
145
+ # library now just contains a list of indices which go with each primitive
146
+ self.grammar = grammar
147
+ self.library = {}
148
+ self.n_grammars = 0
149
+ for prim in grammar.primitives:
150
+ numberOfArguments = len(prim.infer().functionArguments())
151
+ idx_list = list(range(self.n_grammars, self.n_grammars+numberOfArguments))
152
+ self.library[prim] = idx_list
153
+ self.n_grammars += numberOfArguments
154
+
155
+ # We add two extra grammars: one for when there is no parent, and one for when the parent is a variable
156
+ self.n_grammars += 2
157
+ self.transitionMatrix = LowRank(inputDimensionality, self.n_grammars, len(grammar) + 1, R)
158
+
159
+ def grammarFromVector(self, logProductions):
160
+ return Grammar(logProductions[-1].view(1),
161
+ [(logProductions[k].view(1), t, program)
162
+ for k, (_, t, program) in enumerate(self.grammar.productions)],
163
+ continuationType=self.grammar.continuationType)
164
+
165
+ def forward(self, x):
166
+ assert len(x.size()) == 1, "contextual grammar doesn't currently support batching"
167
+
168
+ transitionMatrix = self.transitionMatrix(x)
169
+
170
+ return ContextualGrammar(self.grammarFromVector(transitionMatrix[-1]), self.grammarFromVector(transitionMatrix[-2]),
171
+ {prim: [self.grammarFromVector(transitionMatrix[j]) for j in js]
172
+ for prim, js in self.library.items()} )
173
+
174
+ def vectorizedLogLikelihoods(self, x, summaries):
175
+ B = len(summaries)
176
+ G = len(self.grammar) + 1
177
+
178
+ # Which column of the transition matrix corresponds to which primitive
179
+ primitiveColumn = {p: c
180
+ for c, (_1,_2,p) in enumerate(self.grammar.productions) }
181
+ primitiveColumn[Index(0)] = G - 1
182
+ # Which row of the transition matrix corresponds to which context
183
+ contextRow = {(parent, index): r
184
+ for parent, indices in self.library.items()
185
+ for index, r in enumerate(indices) }
186
+ contextRow[(None,None)] = self.n_grammars - 1
187
+ contextRow[(Index(0),None)] = self.n_grammars - 2
188
+
189
+ transitionMatrix = self.transitionMatrix(x)
190
+
191
+ # uses[b][g][p] is # uses of primitive p by summary b for parent g
192
+ uses = np.zeros((B,self.n_grammars,len(self.grammar)+1))
193
+ for b,summary in enumerate(summaries):
194
+ for e, ss in summary.library.items():
195
+ for g,s in zip(self.library[e], ss):
196
+ assert g < self.n_grammars - 2
197
+ for p, production in enumerate(self.grammar.primitives):
198
+ uses[b,g,p] = s.uses.get(production, 0.)
199
+ uses[b,g,len(self.grammar)] = s.uses.get(Index(0), 0)
200
+
201
+ # noParent: this is the last network output
202
+ for p, production in enumerate(self.grammar.primitives):
203
+ uses[b, self.n_grammars - 1, p] = summary.noParent.uses.get(production, 0.)
204
+ uses[b, self.n_grammars - 1, G - 1] = summary.noParent.uses.get(Index(0), 0.)
205
+
206
+ # variableParent: this is the penultimate network output
207
+ for p, production in enumerate(self.grammar.primitives):
208
+ uses[b, self.n_grammars - 2, p] = summary.variableParent.uses.get(production, 0.)
209
+ uses[b, self.n_grammars - 2, G - 1] = summary.variableParent.uses.get(Index(0), 0.)
210
+
211
+ use_cuda = x.device.type == 'cuda'
+ uses = maybe_cuda(torch.tensor(uses).float(), use_cuda)
212
+ numerator = uses.view(B, -1) @ transitionMatrix.view(-1)
213
+
214
+ constant = np.zeros(B)
215
+ for b,summary in enumerate(summaries):
216
+ constant[b] += summary.noParent.constant + summary.variableParent.constant
217
+ for ss in summary.library.values():
218
+ for s in ss:
219
+ constant[b] += s.constant
220
+
221
+ numerator = numerator + maybe_cuda(torch.tensor(constant).float(),use_cuda)
222
+
223
+ # Calculate the god-awful denominator
224
+ # Map from (parent, index, {set-of-alternatives}) to [occurrences-in-summary-zero, occurrences-in-summary-one, ...]
225
+ alternativeSet = {}
226
+ for b,summary in enumerate(summaries):
227
+ for normalizer, frequency in summary.noParent.normalizers.items():
228
+ k = (None,None,normalizer)
229
+ alternativeSet[k] = alternativeSet.get(k, np.zeros(B))
230
+ alternativeSet[k][b] += frequency
231
+ for normalizer, frequency in summary.variableParent.normalizers.items():
232
+ k = (Index(0),None,normalizer)
233
+ alternativeSet[k] = alternativeSet.get(k, np.zeros(B))
234
+ alternativeSet[k][b] += frequency
235
+ for parent, ss in summary.library.items():
236
+ for argumentIndex, s in enumerate(ss):
237
+ for normalizer, frequency in s.normalizers.items():
238
+ k = (parent, argumentIndex, normalizer)
239
+ alternativeSet[k] = alternativeSet.get(k, np.zeros(B))
240
+ alternativeSet[k][b] += frequency
241
+
242
+ # Calculate each distinct normalizing constant
243
+ alternativeNormalizer = {}
244
+ for parent, index, alternatives in alternativeSet:
245
+ r = transitionMatrix[contextRow[(parent, index)]]
246
+ entries = r[ [primitiveColumn[alternative] for alternative in alternatives ]]
247
+ alternativeNormalizer[(parent, index, alternatives)] = torch.logsumexp(entries, dim=0)
248
+
249
+ # Concatenate the normalizers into a vector
250
+ normalizerKeys = list(alternativeSet.keys())
251
+ normalizerVector = torch.cat([ alternativeNormalizer[k] for k in normalizerKeys])
252
+
253
+ assert False, "This function is still in progress."
254
+
255
+
256
+ def batchedLogLikelihoods(self, xs, summaries):
257
+ """Takes as input BxinputDimensionality vector & B likelihood summaries;
258
+ returns B-dimensional vector containing log likelihood of each summary"""
259
+ use_cuda = xs.device.type == 'cuda'
260
+
261
+ B = xs.shape[0]
262
+ G = len(self.grammar) + 1
263
+ assert len(summaries) == B
264
+
265
+ # logProductions: Bx n_grammars x G
266
+ logProductions = self.transitionMatrix(xs)
267
+ # uses[b][g][p] is # uses of primitive p by summary b for parent g
268
+ uses = np.zeros((B,self.n_grammars,len(self.grammar)+1))
269
+ for b,summary in enumerate(summaries):
270
+ for e, ss in summary.library.items():
271
+ for g,s in zip(self.library[e], ss):
272
+ assert g < self.n_grammars - 2
273
+ for p, production in enumerate(self.grammar.primitives):
274
+ uses[b,g,p] = s.uses.get(production, 0.)
275
+ uses[b,g,len(self.grammar)] = s.uses.get(Index(0), 0)
276
+
277
+ # noParent: this is the last network output
278
+ for p, production in enumerate(self.grammar.primitives):
279
+ uses[b, self.n_grammars - 1, p] = summary.noParent.uses.get(production, 0.)
280
+ uses[b, self.n_grammars - 1, G - 1] = summary.noParent.uses.get(Index(0), 0.)
281
+
282
+ # variableParent: this is the penultimate network output
283
+ for p, production in enumerate(self.grammar.primitives):
284
+ uses[b, self.n_grammars - 2, p] = summary.variableParent.uses.get(production, 0.)
285
+ uses[b, self.n_grammars - 2, G - 1] = summary.variableParent.uses.get(Index(0), 0.)
286
+
287
+ numerator = (logProductions*maybe_cuda(torch.tensor(uses).float(),use_cuda)).view(B,-1).sum(1)
288
+
289
+ constant = np.zeros(B)
290
+ for b,summary in enumerate(summaries):
291
+ constant[b] += summary.noParent.constant + summary.variableParent.constant
292
+ for ss in summary.library.values():
293
+ for s in ss:
294
+ constant[b] += s.constant
295
+
296
+ numerator += maybe_cuda(torch.tensor(constant).float(),use_cuda)
297
+
298
+ if True:
299
+
300
+ # Calculate the god-awful denominator
301
+ alternativeSet = set()
302
+ for summary in summaries:
303
+ for normalizer in summary.noParent.normalizers: alternativeSet.add(normalizer)
304
+ for normalizer in summary.variableParent.normalizers: alternativeSet.add(normalizer)
305
+ for ss in summary.library.values():
306
+ for s in ss:
307
+ for normalizer in s.normalizers: alternativeSet.add(normalizer)
308
+ alternativeSet = list(alternativeSet)
309
+
310
+ mask = np.zeros((len(alternativeSet), G))
311
+ for tau in range(len(alternativeSet)):
312
+ for p, production in enumerate(self.grammar.primitives):
313
+ mask[tau,p] = 0. if production in alternativeSet[tau] else NEGATIVEINFINITY
314
+ mask[tau, G - 1] = 0. if Index(0) in alternativeSet[tau] else NEGATIVEINFINITY
315
+ mask = maybe_cuda(torch.tensor(mask).float(), use_cuda)
316
+
317
+ z = mask.repeat(self.n_grammars,1,1).repeat(B,1,1,1) + \
318
+ logProductions.repeat(len(alternativeSet),1,1,1).transpose(0,1).transpose(1,2)
319
+ z = torch.logsumexp(z, 3) # pytorch 1.0 dependency
320
+
321
+ N = np.zeros((B, self.n_grammars, len(alternativeSet)))
322
+ for b, summary in enumerate(summaries):
323
+ for e, ss in summary.library.items():
324
+ for g,s in zip(self.library[e], ss):
325
+ assert g < self.n_grammars - 2
326
+ for r, alternatives in enumerate(alternativeSet):
327
+ N[b,g,r] = s.normalizers.get(alternatives, 0.)
328
+ # noParent: this is the last network output
329
+ for r, alternatives in enumerate(alternativeSet):
330
+ N[b,self.n_grammars - 1,r] = summary.noParent.normalizers.get(alternatives, 0.)
331
+ # variableParent: this is the penultimate network output
332
+ for r, alternatives in enumerate(alternativeSet):
333
+ N[b,self.n_grammars - 2,r] = summary.variableParent.normalizers.get(alternatives, 0.)
334
+ N = maybe_cuda(torch.tensor(N).float(), use_cuda)
335
+ denominator = (N*z).sum(1).sum(1)
336
+ else:
337
+ gs = [ self(xs[b]) for b in range(B) ]
338
+ denominator = torch.cat([ summary.denominator(g) for summary,g in zip(summaries, gs) ])
339
+
340
+
341
+
342
+
343
+
344
+ ll = numerator - denominator
345
+
346
+ if False: # verifying that batching works correctly
347
+ gs = [ self(xs[b]) for b in range(B) ]
348
+ _l = torch.cat([ summary.logLikelihood(g) for summary,g in zip(summaries, gs) ])
349
+ assert torch.all((ll - _l).abs() < 0.0001)
350
+ return ll
351
+
352
+ class ContextualGrammarNetwork_Mask(nn.Module):
353
+ def __init__(self, inputDimensionality, grammar):
354
+ """Bigram model, but where the bigram transitions are unconditional.
355
+ Individual primitive probabilities are still conditional (predicted by neural network)
356
+ """
357
+
358
+ super(ContextualGrammarNetwork_Mask, self).__init__()
359
+
360
+ self.grammar = grammar
361
+
362
+ # library now just contains a list of indices which go with each primitive
363
+ self.grammar = grammar
364
+ self.library = {}
365
+ self.n_grammars = 0
366
+ for prim in grammar.primitives:
367
+ numberOfArguments = len(prim.infer().functionArguments())
368
+ idx_list = list(range(self.n_grammars, self.n_grammars+numberOfArguments))
369
+ self.library[prim] = idx_list
370
+ self.n_grammars += numberOfArguments
371
+
372
+ # We add two extra grammars: one for when there is no parent, and one for when the parent is a variable
373
+ self.n_grammars += 2
374
+ self._transitionMatrix = nn.Parameter(nn.init.xavier_uniform(torch.Tensor(self.n_grammars, len(grammar) + 1)))
375
+ self._logProductions = nn.Linear(inputDimensionality, len(grammar)+1)
376
+
377
+ def transitionMatrix(self, x):
378
+ if len(x.shape) == 1: # not batched
379
+ return self._logProductions(x) + self._transitionMatrix # will broadcast
380
+ elif len(x.shape) == 2: # batched
381
+ return self._logProductions(x).unsqueeze(1).repeat(1,self.n_grammars,1) + \
382
+ self._transitionMatrix.unsqueeze(0).repeat(x.size(0),1,1)
383
+ else:
384
+ assert False, "unknown shape for transition matrix input"
385
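The decomposition implemented above, in shapes (hypothetical sizes): every parent slot g shares one task-conditional unigram score vector, offset by an unconditional per-parent transition row.

import torch
D, n_grammars, G = 16, 5, 8
x = torch.randn(D)
unigram = torch.nn.Linear(D, G)   # task-conditional part, shared across parents
T = torch.randn(n_grammars, G)    # unconditional bigram offsets, one row per parent slot
logits = unigram(x) + T           # broadcasts to [n_grammars, G]
assert logits.shape == (n_grammars, G)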
+
386
+ def grammarFromVector(self, logProductions):
387
+ return Grammar(logProductions[-1].view(1),
388
+ [(logProductions[k].view(1), t, program)
389
+ for k, (_, t, program) in enumerate(self.grammar.productions)],
390
+ continuationType=self.grammar.continuationType)
391
+
392
+ def forward(self, x):
393
+ assert len(x.size()) == 1, "contextual grammar doesn't currently support batching"
394
+
395
+ transitionMatrix = self.transitionMatrix(x)
396
+
397
+ return ContextualGrammar(self.grammarFromVector(transitionMatrix[-1]), self.grammarFromVector(transitionMatrix[-2]),
398
+ {prim: [self.grammarFromVector(transitionMatrix[j]) for j in js]
399
+ for prim, js in self.library.items()} )
400
+
401
+ def batchedLogLikelihoods(self, xs, summaries):
402
+ """Takes as input BxinputDimensionality vector & B likelihood summaries;
403
+ returns B-dimensional vector containing log likelihood of each summary"""
404
+ use_cuda = xs.device.type == 'cuda'
405
+
406
+ B = xs.shape[0]
407
+ G = len(self.grammar) + 1
408
+ assert len(summaries) == B
409
+
410
+ # logProductions: Bx n_grammars x G
411
+ logProductions = self.transitionMatrix(xs)
412
+ # uses[b][g][p] is # uses of primitive p by summary b for parent g
413
+ uses = np.zeros((B,self.n_grammars,len(self.grammar)+1))
414
+ for b,summary in enumerate(summaries):
415
+ for e, ss in summary.library.items():
416
+ for g,s in zip(self.library[e], ss):
417
+ assert g < self.n_grammars - 2
418
+ for p, production in enumerate(self.grammar.primitives):
419
+ uses[b,g,p] = s.uses.get(production, 0.)
420
+ uses[b,g,len(self.grammar)] = s.uses.get(Index(0), 0)
421
+
422
+ # noParent: this is the last network output
423
+ for p, production in enumerate(self.grammar.primitives):
424
+ uses[b, self.n_grammars - 1, p] = summary.noParent.uses.get(production, 0.)
425
+ uses[b, self.n_grammars - 1, G - 1] = summary.noParent.uses.get(Index(0), 0.)
426
+
427
+ # variableParent: this is the penultimate network output
428
+ for p, production in enumerate(self.grammar.primitives):
429
+ uses[b, self.n_grammars - 2, p] = summary.variableParent.uses.get(production, 0.)
430
+ uses[b, self.n_grammars - 2, G - 1] = summary.variableParent.uses.get(Index(0), 0.)
431
+
432
+ numerator = (logProductions*maybe_cuda(torch.tensor(uses).float(),use_cuda)).view(B,-1).sum(1)
433
+
434
+ constant = np.zeros(B)
435
+ for b,summary in enumerate(summaries):
436
+ constant[b] += summary.noParent.constant + summary.variableParent.constant
437
+ for ss in summary.library.values():
438
+ for s in ss:
439
+ constant[b] += s.constant
440
+
441
+ numerator += maybe_cuda(torch.tensor(constant).float(),use_cuda)
442
+
443
+ if True:
444
+
445
+ # Calculate the god-awful denominator
446
+ alternativeSet = set()
447
+ for summary in summaries:
448
+ for normalizer in summary.noParent.normalizers: alternativeSet.add(normalizer)
449
+ for normalizer in summary.variableParent.normalizers: alternativeSet.add(normalizer)
450
+ for ss in summary.library.values():
451
+ for s in ss:
452
+ for normalizer in s.normalizers: alternativeSet.add(normalizer)
453
+ alternativeSet = list(alternativeSet)
454
+
455
+ mask = np.zeros((len(alternativeSet), G))
456
+ for tau in range(len(alternativeSet)):
457
+ for p, production in enumerate(self.grammar.primitives):
458
+ mask[tau,p] = 0. if production in alternativeSet[tau] else NEGATIVEINFINITY
459
+ mask[tau, G - 1] = 0. if Index(0) in alternativeSet[tau] else NEGATIVEINFINITY
460
+ mask = maybe_cuda(torch.tensor(mask).float(), use_cuda)
461
+
462
+ z = mask.repeat(self.n_grammars,1,1).repeat(B,1,1,1) + \
463
+ logProductions.repeat(len(alternativeSet),1,1,1).transpose(0,1).transpose(1,2)
464
+ z = torch.logsumexp(z, 3) # pytorch 1.0 dependency
465
+
466
+ N = np.zeros((B, self.n_grammars, len(alternativeSet)))
467
+ for b, summary in enumerate(summaries):
468
+ for e, ss in summary.library.items():
469
+ for g,s in zip(self.library[e], ss):
470
+ assert g < self.n_grammars - 2
471
+ for r, alternatives in enumerate(alternativeSet):
472
+ N[b,g,r] = s.normalizers.get(alternatives, 0.)
473
+ # noParent: this is the last network output
474
+ for r, alternatives in enumerate(alternativeSet):
475
+ N[b,self.n_grammars - 1,r] = summary.noParent.normalizers.get(alternatives, 0.)
476
+ # variableParent: this is the penultimate network output
477
+ for r, alternatives in enumerate(alternativeSet):
478
+ N[b,self.n_grammars - 2,r] = summary.variableParent.normalizers.get(alternatives, 0.)
479
+ N = maybe_cuda(torch.tensor(N).float(), use_cuda)
480
+ denominator = (N*z).sum(1).sum(1)
481
+ else:
482
+ gs = [ self(xs[b]) for b in range(B) ]
483
+ denominator = torch.cat([ summary.denominator(g) for summary,g in zip(summaries, gs) ])
484
+
485
+
486
+
487
+
488
+
489
+ ll = numerator - denominator
490
+
491
+ if False: # verifying that batching works correctly
492
+ gs = [ self(xs[b]) for b in range(B) ]
493
+ _l = torch.cat([ summary.logLikelihood(g) for summary,g in zip(summaries, gs) ])
494
+ assert torch.all((ll - _l).abs() < 0.0001)
495
+ return ll
496
+
497
+
498
+
499
+ class ContextualGrammarNetwork(nn.Module):
500
+ """Like GrammarNetwork but ~contextual~"""
501
+ def __init__(self, inputDimensionality, grammar):
502
+ super(ContextualGrammarNetwork, self).__init__()
503
+
504
+ # library now just contains a list of indices which go with each primitive
505
+ self.grammar = grammar
506
+ self.library = {}
507
+ self.n_grammars = 0
508
+ for prim in grammar.primitives:
509
+ numberOfArguments = len(prim.infer().functionArguments())
510
+ idx_list = list(range(self.n_grammars, self.n_grammars+numberOfArguments))
511
+ self.library[prim] = idx_list
512
+ self.n_grammars += numberOfArguments
513
+
514
+ # We add two extra grammars: one for when there is no parent, and one for when the parent is a variable
515
+ self.n_grammars += 2
516
+ self.network = nn.Linear(inputDimensionality, (self.n_grammars)*(len(grammar) + 1))
517
+
518
+
519
+ def grammarFromVector(self, logProductions):
520
+ return Grammar(logProductions[-1].view(1),
521
+ [(logProductions[k].view(1), t, program)
522
+ for k, (_, t, program) in enumerate(self.grammar.productions)],
523
+ continuationType=self.grammar.continuationType)
524
+
525
+ def forward(self, x):
526
+ assert len(x.size()) == 1, "contextual grammar doesn't currently support batching"
527
+
528
+ allVars = self.network(x).view(self.n_grammars, -1)
529
+ return ContextualGrammar(self.grammarFromVector(allVars[-1]), self.grammarFromVector(allVars[-2]),
530
+ {prim: [self.grammarFromVector(allVars[j]) for j in js]
531
+ for prim, js in self.library.items()} )
532
+
533
+ def batchedLogLikelihoods(self, xs, summaries):
534
+ """Takes as input BxinputDimensionality vector & B likelihood summaries;
535
+ returns B-dimensional vector containing log likelihood of each summary"""
536
+ use_cuda = xs.device.type == 'cuda'
537
+
538
+ B = xs.shape[0]
539
+ G = len(self.grammar) + 1
540
+ assert len(summaries) == B
541
+
542
+ # logProductions: Bx n_grammars x G
543
+ logProductions = self.network(xs).view(B, self.n_grammars, G)
544
+ # uses[b][g][p] is # uses of primitive p by summary b for parent g
545
+ uses = np.zeros((B,self.n_grammars,len(self.grammar)+1))
546
+ for b,summary in enumerate(summaries):
547
+ for e, ss in summary.library.items():
548
+ for g,s in zip(self.library[e], ss):
549
+ assert g < self.n_grammars - 2
550
+ for p, production in enumerate(self.grammar.primitives):
551
+ uses[b,g,p] = s.uses.get(production, 0.)
552
+ uses[b,g,len(self.grammar)] = s.uses.get(Index(0), 0)
553
+
554
+ # noParent: this is the last network output
555
+ for p, production in enumerate(self.grammar.primitives):
556
+ uses[b, self.n_grammars - 1, p] = summary.noParent.uses.get(production, 0.)
557
+ uses[b, self.n_grammars - 1, G - 1] = summary.noParent.uses.get(Index(0), 0.)
558
+
559
+ # variableParent: this is the penultimate network output
560
+ for p, production in enumerate(self.grammar.primitives):
561
+ uses[b, self.n_grammars - 2, p] = summary.variableParent.uses.get(production, 0.)
562
+ uses[b, self.n_grammars - 2, G - 1] = summary.variableParent.uses.get(Index(0), 0.)
563
+
564
+ numerator = (logProductions*maybe_cuda(torch.tensor(uses).float(),use_cuda)).view(B,-1).sum(1)
565
+
566
+ constant = np.zeros(B)
567
+ for b,summary in enumerate(summaries):
568
+ constant[b] += summary.noParent.constant + summary.variableParent.constant
569
+ for ss in summary.library.values():
570
+ for s in ss:
571
+ constant[b] += s.constant
572
+
573
+ numerator += maybe_cuda(torch.tensor(constant).float(),use_cuda)
574
+
575
+ # Calculate the god-awful denominator
576
+ alternativeSet = set()
577
+ for summary in summaries:
578
+ for normalizer in summary.noParent.normalizers: alternativeSet.add(normalizer)
579
+ for normalizer in summary.variableParent.normalizers: alternativeSet.add(normalizer)
580
+ for ss in summary.library.values():
581
+ for s in ss:
582
+ for normalizer in s.normalizers: alternativeSet.add(normalizer)
583
+ alternativeSet = list(alternativeSet)
584
+
585
+ mask = np.zeros((len(alternativeSet), G))
586
+ for tau in range(len(alternativeSet)):
587
+ for p, production in enumerate(self.grammar.primitives):
588
+ mask[tau,p] = 0. if production in alternativeSet[tau] else NEGATIVEINFINITY
589
+ mask[tau, G - 1] = 0. if Index(0) in alternativeSet[tau] else NEGATIVEINFINITY
590
+ mask = maybe_cuda(torch.tensor(mask).float(), use_cuda)
591
+
592
+ z = mask.repeat(self.n_grammars,1,1).repeat(B,1,1,1) + \
593
+ logProductions.repeat(len(alternativeSet),1,1,1).transpose(0,1).transpose(1,2)
594
+ z = torch.logsumexp(z, 3) # pytorch 1.0 dependency
595
+
596
+ N = np.zeros((B, self.n_grammars, len(alternativeSet)))
597
+ for b, summary in enumerate(summaries):
598
+ for e, ss in summary.library.items():
599
+ for g,s in zip(self.library[e], ss):
600
+ assert g < self.n_grammars - 2
601
+ for r, alternatives in enumerate(alternativeSet):
602
+ N[b,g,r] = s.normalizers.get(alternatives, 0.)
603
+ # noParent: this is the last network output
604
+ for r, alternatives in enumerate(alternativeSet):
605
+ N[b,self.n_grammars - 1,r] = summary.noParent.normalizers.get(alternatives, 0.)
606
+ # variableParent: this is the penultimate network output
607
+ for r, alternatives in enumerate(alternativeSet):
608
+ N[b,self.n_grammars - 2,r] = summary.variableParent.normalizers.get(alternatives, 0.)
609
+ N = maybe_cuda(torch.tensor(N).float(), use_cuda)
610
+
611
+
612
+
613
+ denominator = (N*z).sum(1).sum(1)
614
+ ll = numerator - denominator
615
+
616
+ if False: # verifying that batching works correctly
617
+ gs = [ self(xs[b]) for b in range(B) ]
618
+ _l = torch.cat([ summary.logLikelihood(g) for summary,g in zip(summaries, gs) ])
619
+ assert torch.all((ll - _l).abs() < 0.0001)
620
+
621
+ return ll
622
+
623
+
624
+ class RecognitionModel(nn.Module):
625
+ def __init__(self,featureExtractor,grammar,hidden=[64],activation="tanh",
626
+ rank=None,contextual=False,mask=False,
627
+ cuda=False,
628
+ previousRecognitionModel=None,
629
+ id=0):
630
+ super(RecognitionModel, self).__init__()
631
+ self.id = id
632
+ self.trained=False
633
+ self.use_cuda = cuda
634
+
635
+ self.featureExtractor = featureExtractor
636
+ # Sanity check - make sure that all of the parameters of the
637
+ # feature extractor were added to our parameters as well
638
+ if hasattr(featureExtractor, 'parameters'):
639
+ for parameter in featureExtractor.parameters():
640
+ assert any(myParameter is parameter for myParameter in self.parameters())
641
+
642
+ # Build the multilayer perceptron that is sandwiched between the feature extractor and the grammar
643
+ if activation == "sigmoid":
644
+ activation = nn.Sigmoid
645
+ elif activation == "relu":
646
+ activation = nn.ReLU
647
+ elif activation == "tanh":
648
+ activation = nn.Tanh
649
+ else:
650
+ raise Exception('Unknown activation function ' + str(activation))
651
+ self._MLP = nn.Sequential(*[ layer
652
+ for j in range(len(hidden))
653
+ for layer in [
654
+ nn.Linear(([featureExtractor.outputDimensionality] + hidden)[j],
655
+ hidden[j]),
656
+ activation()]])
657
+
658
+ self.entropy = Entropy()
659
+
660
+ if len(hidden) > 0:
661
+ self.outputDimensionality = self._MLP[-2].out_features
662
+ assert self.outputDimensionality == hidden[-1]
663
+ else:
664
+ self.outputDimensionality = self.featureExtractor.outputDimensionality
665
+
666
+ self.contextual = contextual
667
+ if self.contextual:
668
+ if mask:
669
+ self.grammarBuilder = ContextualGrammarNetwork_Mask(self.outputDimensionality, grammar)
670
+ else:
671
+ self.grammarBuilder = ContextualGrammarNetwork_LowRank(self.outputDimensionality, grammar, rank)
672
+ else:
673
+ self.grammarBuilder = GrammarNetwork(self.outputDimensionality, grammar)
674
+
675
+ self.grammar = ContextualGrammar.fromGrammar(grammar) if contextual else grammar
676
+ self.generativeModel = grammar
677
+
678
+ self._auxiliaryPrediction = nn.Linear(self.featureExtractor.outputDimensionality,
679
+ len(self.grammar.primitives))
680
+ self._auxiliaryLoss = nn.BCEWithLogitsLoss()
681
+
682
+ if cuda: self.cuda()
683
+
684
+ if previousRecognitionModel:
685
+ self._MLP.load_state_dict(previousRecognitionModel._MLP.state_dict())
686
+ self.featureExtractor.load_state_dict(previousRecognitionModel.featureExtractor.state_dict())
687
+
688
+ def auxiliaryLoss(self, frontier, features):
689
+ # Compute a vector of uses
690
+ ls = frontier.bestPosterior.program
691
+ def uses(summary):
692
+ if hasattr(summary, 'uses'):
693
+ return torch.tensor([ float(int(p in summary.uses))
694
+ for p in self.generativeModel.primitives ])
695
+ assert hasattr(summary, 'noParent')
696
+ u = uses(summary.noParent) + uses(summary.variableParent)
697
+ for ss in summary.library.values():
698
+ for s in ss:
699
+ u += uses(s)
700
+ return u
701
+ u = uses(ls)
702
+ u[u > 1.] = 1.
703
+ if self.use_cuda: u = u.cuda()
704
+ al = self._auxiliaryLoss(self._auxiliaryPrediction(features), u)
705
+ return al
706
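A toy sketch of this auxiliary objective (made-up sizes, not dreamcoder API): the target is a multi-hot vector marking which library primitives the best program uses, and BCEWithLogitsLoss pushes one logit per primitive toward it.

import torch
logits = torch.randn(5)                           # one prediction per library primitive
target = torch.tensor([1.0, 0.0, 1.0, 0.0, 0.0])  # 1 iff the primitive appears in the program
loss = torch.nn.BCEWithLogitsLoss()(logits, target)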
+
707
+ def taskEmbeddings(self, tasks):
708
+ return {task: self.featureExtractor.featuresOfTask(task).data.cpu().numpy()
709
+ for task in tasks}
710
+
711
+ def forward(self, features):
712
+ """returns either a Grammar or a ContextualGrammar
713
+ Takes as input the output of featureExtractor.featuresOfTask"""
714
+ features = self._MLP(features)
715
+ return self.grammarBuilder(features)
716
+
717
+ def auxiliaryPrimitiveEmbeddings(self):
718
+ """Returns the actual outputDimensionality weight vectors for each of the primitives."""
719
+ auxiliaryWeights = self._auxiliaryPrediction.weight.data.cpu().numpy()
720
+ primitivesDict = {self.grammar.primitives[i] : auxiliaryWeights[i, :] for i in range(len(self.grammar.primitives))}
721
+ return primitivesDict
722
+
723
+ def grammarOfTask(self, task):
724
+ features = self.featureExtractor.featuresOfTask(task)
725
+ if features is None: return None
726
+ return self(features)
727
+
728
+ def grammarLogProductionsOfTask(self, task):
729
+ """Returns the grammar logits from non-contextual models."""
730
+
731
+ features = self.featureExtractor.featuresOfTask(task)
732
+ if features is None: return None
733
+
734
+ if hasattr(self, 'hiddenLayers'):
735
+ # Backward compatibility with old checkpoints.
736
+ for layer in self.hiddenLayers:
737
+ features = self.activation(layer(features))
738
+ # return features
739
+ return self.noParent[1](features)
740
+ else:
741
+ features = self._MLP(features)
742
+
743
+ if self.contextual:
744
+ if hasattr(self.grammarBuilder, 'variableParent'):
745
+ return self.grammarBuilder.variableParent.logProductions(features)
746
+ elif hasattr(self.grammarBuilder, 'network'):
747
+ return self.grammarBuilder.network(features).view(-1)
748
+ elif hasattr(self.grammarBuilder, 'transitionMatrix'):
749
+ return self.grammarBuilder.transitionMatrix(features).view(-1)
750
+ else:
751
+ assert False
752
+ else:
753
+ return self.grammarBuilder.logProductions(features)
754
+
755
+ def grammarFeatureLogProductionsOfTask(self, task):
756
+ return torch.tensor(self.grammarOfTask(task).untorch().featureVector())
757
+
758
+ def grammarLogProductionDistanceToTask(self, task, tasks):
759
+ """Returns the cosine similarity of all other tasks to a given task."""
760
+ taskLogits = self.grammarLogProductionsOfTask(task).unsqueeze(0) # Change to [1, D]
761
+ assert taskLogits is not None, 'Grammar log productions are not defined for this task.'
762
+ otherTasks = [t for t in tasks if t is not task] # [nTasks -1 , D]
763
+
764
+ # Build matrix of all other tasks.
765
+ otherLogits = torch.stack([self.grammarLogProductionsOfTask(t) for t in otherTasks])
766
+ cos = nn.CosineSimilarity(dim=1, eps=1e-6)
767
+ cosMatrix = cos(taskLogits, otherLogits)
768
+ return cosMatrix.data.cpu().numpy()
769
+
770
+ def grammarEntropyOfTask(self, task):
771
+ """Returns the entropy of the grammar distribution from non-contextual models for a task."""
772
+ grammarLogProductionsOfTask = self.grammarLogProductionsOfTask(task)
773
+
774
+ if grammarLogProductionsOfTask is None: return None
775
+
776
+ if hasattr(self, 'entropy'):
777
+ return self.entropy(grammarLogProductionsOfTask)
778
+ else:
779
+ e = Entropy()
780
+ return e(grammarLogProductionsOfTask)
781
+
782
+ def taskAuxiliaryLossLayer(self, tasks):
783
+ return {task: self._auxiliaryPrediction(self.featureExtractor.featuresOfTask(task)).view(-1).data.cpu().numpy()
784
+ for task in tasks}
785
+
786
+ def taskGrammarFeatureLogProductions(self, tasks):
787
+ return {task: self.grammarFeatureLogProductionsOfTask(task).data.cpu().numpy()
788
+ for task in tasks}
789
+
790
+ def taskGrammarLogProductions(self, tasks):
791
+ return {task: self.grammarLogProductionsOfTask(task).data.cpu().numpy()
792
+ for task in tasks}
793
+
794
+ def taskGrammarStartProductions(self, tasks):
795
+ return {task: np.array([l for l,_1,_2 in g.productions ])
796
+ for task in tasks
797
+ for g in [self.grammarOfTask(task).untorch().noParent] }
798
+
799
+ def taskHiddenStates(self, tasks):
800
+ return {task: self._MLP(self.featureExtractor.featuresOfTask(task)).view(-1).data.cpu().numpy()
801
+ for task in tasks}
802
+
803
+ def taskGrammarEntropies(self, tasks):
804
+ return {task: self.grammarEntropyOfTask(task).data.cpu().numpy()
805
+ for task in tasks}
806
+
807
+ def frontierKL(self, frontier, auxiliary=False, vectorized=True):
808
+ features = self.featureExtractor.featuresOfTask(frontier.task)
809
+ if features is None:
810
+ return None, None
811
+ # Monte Carlo estimate: draw a sample from the frontier
812
+ entry = frontier.sample()
813
+
814
+ al = self.auxiliaryLoss(frontier, features if auxiliary else features.detach())
815
+
816
+ if not vectorized:
817
+ g = self(features)
818
+ return - entry.program.logLikelihood(g), al
819
+ else:
820
+ features = self._MLP(features).unsqueeze(0)
821
+
822
+ ll = self.grammarBuilder.batchedLogLikelihoods(features, [entry.program]).view(-1)
823
+ return -ll, al
824
+
825
+
826
+ def frontierBiasOptimal(self, frontier, auxiliary=False, vectorized=True):
827
+ if not vectorized:
828
+ features = self.featureExtractor.featuresOfTask(frontier.task)
829
+ if features is None: return None, None
830
+ al = self.auxiliaryLoss(frontier, features if auxiliary else features.detach())
831
+ g = self(features)
832
+ summaries = [entry.program for entry in frontier]
833
+ likelihoods = torch.cat([entry.program.logLikelihood(g) + entry.logLikelihood
834
+ for entry in frontier ])
835
+ best = likelihoods.max()
836
+ return -best, al
837
+
838
+ batchSize = len(frontier.entries)
839
+ features = self.featureExtractor.featuresOfTask(frontier.task)
840
+ if features is None: return None, None
841
+ al = self.auxiliaryLoss(frontier, features if auxiliary else features.detach())
842
+ features = self._MLP(features)
843
+ features = features.expand(batchSize, features.size(-1)) # TODO
844
+ lls = self.grammarBuilder.batchedLogLikelihoods(features, [entry.program for entry in frontier])
845
+ actual_ll = torch.Tensor([ entry.logLikelihood for entry in frontier])
846
+ lls = lls + (actual_ll.cuda() if self.use_cuda else actual_ll)
847
+ ml = -lls.max() #Beware that inputs to max change output type
848
+ return ml, al
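The bias-optimal loss above scores a frontier by its single best entry: each program's model log-likelihood is added to its data log-likelihood, and only the maximum contributes to the gradient. A minimal standalone sketch of that reduction, with made-up numbers rather than the model's real likelihoods:

    import torch

    # Hypothetical values: recognition-model log-likelihoods for 3 frontier
    # programs, plus each program's data log-likelihood from the frontier.
    model_lls = torch.tensor([-5.2, -3.1, -8.0])
    data_lls = torch.tensor([0.0, -1.5, 0.0])

    # Bias-optimal objective: only the best total-scoring program contributes.
    loss = -(model_lls + data_lls).max()
    print(loss)  # tensor(4.6000): the second program wins (-3.1 + -1.5 = -4.6)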
+
+     def replaceProgramsWithLikelihoodSummaries(self, frontier):
+         return Frontier(
+             [FrontierEntry(
+                 program=self.grammar.closedLikelihoodSummary(frontier.task.request, e.program),
+                 logLikelihood=e.logLikelihood,
+                 logPrior=e.logPrior) for e in frontier],
+             task=frontier.task)
+
+     def train(self, frontiers, _=None, steps=None, lr=0.001, topK=5, CPUs=1,
+               timeout=None, evaluationTimeout=0.001,
+               helmholtzFrontiers=[], helmholtzRatio=0., helmholtzBatch=500,
+               biasOptimal=None, defaultRequest=None, auxLoss=False, vectorized=True):
+         """
+         helmholtzRatio: What fraction of the training data should be forward samples from the generative model?
+         helmholtzFrontiers: Frontiers from programs enumerated from the generative model (optional).
+         If helmholtzFrontiers is not provided then we will sample programs during training.
+         """
+         assert (steps is not None) or (timeout is not None), \
+             "Cannot train recognition model without either a bound on the number of gradient steps or a bound on the training time"
+         if steps is None: steps = 9999999
+         if biasOptimal is None: biasOptimal = len(helmholtzFrontiers) > 0
+
+         requests = [frontier.task.request for frontier in frontiers]
+         if len(requests) == 0 and helmholtzRatio > 0 and len(helmholtzFrontiers) == 0:
+             assert defaultRequest is not None, "You are trying to do random Helmholtz training, but don't have any frontiers. Therefore we would not know the type of the program to sample. Try specifying defaultRequest=..."
+             requests = [defaultRequest]
+         frontiers = [frontier.topK(topK).normalize()
+                      for frontier in frontiers if not frontier.empty]
+         if len(frontiers) == 0:
+             eprint("You didn't give me any nonempty replay frontiers to learn from. Going to learn from 100% Helmholtz samples")
+             helmholtzRatio = 1.
+
+         # Should we sample programs or use the enumerated programs?
+         randomHelmholtz = len(helmholtzFrontiers) == 0
+
+         class HelmholtzEntry:
+             def __init__(self, frontier, owner):
+                 self.request = frontier.task.request
+                 self.task = None
+                 self.programs = [e.program for e in frontier]
+                 self.frontier = Thunk(lambda: owner.replaceProgramsWithLikelihoodSummaries(frontier))
+                 self.owner = owner
+
+             def clear(self): self.task = None
+
+             def calculateTask(self):
+                 assert self.task is None
+                 p = random.choice(self.programs)
+                 return self.owner.featureExtractor.taskOfProgram(p, self.request)
+
+             def makeFrontier(self):
+                 assert self.task is not None
+                 f = Frontier(self.frontier.force().entries,
+                              task=self.task)
+                 return f
+
+         # Should we recompute tasks on the fly from Helmholtz? This
+         # should be done if the task is stochastic, or if there are
+         # different kinds of inputs on which it could be run. For
+         # example, lists and strings need this; towers and graphics do
+         # not. There is no harm in recomputing the tasks, it just
+         # wastes time.
+         if not hasattr(self.featureExtractor, 'recomputeTasks'):
+             self.featureExtractor.recomputeTasks = True
+         helmholtzFrontiers = [HelmholtzEntry(f, self)
+                               for f in helmholtzFrontiers]
+         random.shuffle(helmholtzFrontiers)
+
+         helmholtzIndex = [0]
+         def getHelmholtz():
+             if randomHelmholtz:
+                 if helmholtzIndex[0] >= len(helmholtzFrontiers):
+                     updateHelmholtzTasks()
+                     helmholtzIndex[0] = 0
+                     return getHelmholtz()
+                 helmholtzIndex[0] += 1
+                 return helmholtzFrontiers[helmholtzIndex[0] - 1].makeFrontier()
+
+             f = helmholtzFrontiers[helmholtzIndex[0]]
+             if f.task is None:
+                 with timing("Evaluated another batch of Helmholtz tasks"):
+                     updateHelmholtzTasks()
+                 return getHelmholtz()
+
+             helmholtzIndex[0] += 1
+             if helmholtzIndex[0] >= len(helmholtzFrontiers):
+                 helmholtzIndex[0] = 0
+                 random.shuffle(helmholtzFrontiers)
+                 if self.featureExtractor.recomputeTasks:
+                     for fp in helmholtzFrontiers:
+                         fp.clear()
+                     return getHelmholtz()  # because we just cleared everything
+             assert f.task is not None
+             return f.makeFrontier()
+
+         def updateHelmholtzTasks():
+             updateCPUs = CPUs if hasattr(self.featureExtractor, 'parallelTaskOfProgram') and self.featureExtractor.parallelTaskOfProgram else 1
+             if updateCPUs > 1: eprint("Updating Helmholtz tasks with", updateCPUs, "CPUs",
+                                       "while using", getThisMemoryUsage(), "memory")
+
+             if randomHelmholtz:
+                 newFrontiers = self.sampleManyHelmholtz(requests, helmholtzBatch, CPUs)
+                 newEntries = []
+                 for f in newFrontiers:
+                     e = HelmholtzEntry(f, self)
+                     e.task = f.task
+                     newEntries.append(e)
+                 helmholtzFrontiers.clear()
+                 helmholtzFrontiers.extend(newEntries)
+                 return
+
+             # Save some memory by freeing up the tasks as we go through them
+             if self.featureExtractor.recomputeTasks:
+                 for hi in range(max(0, helmholtzIndex[0] - helmholtzBatch,
+                                     min(helmholtzIndex[0], len(helmholtzFrontiers)))):
+                     helmholtzFrontiers[hi].clear()
+
+             if hasattr(self.featureExtractor, 'tasksOfPrograms'):
+                 eprint("batching task calculation")
+                 newTasks = self.featureExtractor.tasksOfPrograms(
+                     [random.choice(hf.programs)
+                      for hf in helmholtzFrontiers[helmholtzIndex[0]:helmholtzIndex[0] + helmholtzBatch]],
+                     [hf.request
+                      for hf in helmholtzFrontiers[helmholtzIndex[0]:helmholtzIndex[0] + helmholtzBatch]])
+             else:
+                 newTasks = [hf.calculateTask()
+                             for hf in helmholtzFrontiers[helmholtzIndex[0]:helmholtzIndex[0] + helmholtzBatch]]
+
+             """
+             # catwong: Disabled for ensemble training.
+             newTasks = \
+                 parallelMap(updateCPUs,
+                             lambda f: f.calculateTask(),
+                             helmholtzFrontiers[helmholtzIndex[0]:helmholtzIndex[0] + helmholtzBatch],
+                             seedRandom=True)
+             """
+             badIndices = []
+             endingIndex = min(helmholtzIndex[0] + helmholtzBatch, len(helmholtzFrontiers))
+             for i in range(helmholtzIndex[0], endingIndex):
+                 helmholtzFrontiers[i].task = newTasks[i - helmholtzIndex[0]]
+                 if helmholtzFrontiers[i].task is None: badIndices.append(i)
+             # Permanently kill anything which failed to give a task
+             for i in reversed(badIndices):
+                 assert helmholtzFrontiers[i].task is None
+                 del helmholtzFrontiers[i]
+
+         # We replace each program in the frontier with its likelihoodSummary.
+         # This is because calculating likelihood summaries requires juggling types,
+         # and type stuff is expensive!
+         frontiers = [self.replaceProgramsWithLikelihoodSummaries(f).normalize()
+                      for f in frontiers]
+
+         eprint("(ID=%d): Training a recognition model from %d frontiers, %d%% Helmholtz, feature extractor %s." % (
+             self.id, len(frontiers), int(helmholtzRatio * 100), self.featureExtractor.__class__.__name__))
+         eprint("(ID=%d): Got %d Helmholtz frontiers - random Helmholtz training? : %s" % (
+             self.id, len(helmholtzFrontiers), len(helmholtzFrontiers) == 0))
+         eprint("(ID=%d): Contextual? %s" % (self.id, str(self.contextual)))
+         eprint("(ID=%d): Bias optimal? %s" % (self.id, str(biasOptimal)))
+         eprint(f"(ID={self.id}): Aux loss? {auxLoss} (n.b. we train an 'auxiliary' classifier anyway - this controls whether gradients propagate back to the feature extractor)")
+
+         # The number of Helmholtz samples that we generate at once
+         # should only affect performance and shouldn't affect anything else.
+         helmholtzSamples = []
+
+         optimizer = torch.optim.Adam(self.parameters(), lr=lr, eps=1e-3, amsgrad=True)
+         start = time.time()
+         losses, descriptionLengths, realLosses, dreamLosses, realMDL, dreamMDL = [], [], [], [], [], []
+         classificationLosses = []
+         totalGradientSteps = 0
+         epochs = 9999999
+         for i in range(1, epochs + 1):
+             if timeout and time.time() - start > timeout:
+                 break
+
+             if totalGradientSteps > steps:
+                 break
+
+             if helmholtzRatio < 1.:
+                 permutedFrontiers = list(frontiers)
+                 random.shuffle(permutedFrontiers)
+             else:
+                 permutedFrontiers = [None]
+
+             finishedSteps = False
+             for frontier in permutedFrontiers:
+                 # Randomly decide whether to sample from the generative model
+                 dreaming = random.random() < helmholtzRatio
+                 if dreaming: frontier = getHelmholtz()
+                 self.zero_grad()
+                 loss, classificationLoss = \
+                     self.frontierBiasOptimal(frontier, auxiliary=auxLoss, vectorized=vectorized) if biasOptimal \
+                     else self.frontierKL(frontier, auxiliary=auxLoss, vectorized=vectorized)
+                 if loss is None:
+                     if not dreaming:
+                         eprint("ERROR: Could not extract features during experience replay.")
+                         eprint("Task is:", frontier.task)
+                         eprint("Aborting - we need to be able to extract features of every actual task.")
+                         assert False
+                     else:
+                         continue
+                 if is_torch_invalid(loss):
+                     eprint("Invalid real-data loss!")
+                 else:
+                     (loss + classificationLoss).backward()
+                     classificationLosses.append(classificationLoss.data.item())
+                     optimizer.step()
+                     totalGradientSteps += 1
+                     losses.append(loss.data.item())
+                     descriptionLengths.append(min(-e.logPrior for e in frontier))
+                     if dreaming:
+                         dreamLosses.append(losses[-1])
+                         dreamMDL.append(descriptionLengths[-1])
+                     else:
+                         realLosses.append(losses[-1])
+                         realMDL.append(descriptionLengths[-1])
+                     if totalGradientSteps > steps:
+                         break  # Stop iterating, then print epoch and loss, then break to finish.
+
+             if (i == 1 or i % 10 == 0) and losses:
+                 eprint("(ID=%d): " % self.id, "Epoch", i, "Loss", mean(losses))
+                 if realLosses and dreamLosses:
+                     eprint("(ID=%d): " % self.id, "\t\t(real loss): ", mean(realLosses), "\t(dream loss):", mean(dreamLosses))
+                 eprint("(ID=%d): " % self.id, "\tvs MDL (w/o neural net)", mean(descriptionLengths))
+                 if realMDL and dreamMDL:
+                     eprint("\t\t(real MDL): ", mean(realMDL), "\t(dream MDL):", mean(dreamMDL))
+                 eprint("(ID=%d): " % self.id, "\t%d cumulative gradient steps. %f steps/sec" % (totalGradientSteps,
+                                                                                                 totalGradientSteps / (time.time() - start)))
+                 eprint("(ID=%d): " % self.id, "\t%d-way auxiliary classification loss" % len(self.grammar.primitives), sum(classificationLosses) / len(classificationLosses))
+                 losses, descriptionLengths, realLosses, dreamLosses, realMDL, dreamMDL = [], [], [], [], [], []
+                 classificationLosses = []
+                 gc.collect()
+
+         eprint("(ID=%d): " % self.id, " Trained recognition model in", time.time() - start, "seconds")
+         self.trained = True
+         return self
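One detail of `train` worth calling out: `helmholtzIndex` is a one-element list rather than a plain integer so that the nested closures `getHelmholtz` and `updateHelmholtzTasks` can mutate it without a `nonlocal` declaration. The same idiom in isolation:

    def make_counter():
        index = [0]  # a mutable cell shared with the closure
        def bump():
            index[0] += 1  # rebinding a bare int here would need `nonlocal`
            return index[0]
        return bump

    bump = make_counter()
    assert bump() == 1 and bump() == 2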
+
+     def sampleHelmholtz(self, requests, statusUpdate=None, seed=None):
+         if seed is not None:
+             random.seed(seed)
+         request = random.choice(requests)
+
+         program = self.generativeModel.sample(request, maximumDepth=6, maxAttempts=100)
+         if program is None:
+             return None
+         task = self.featureExtractor.taskOfProgram(program, request)
+
+         if statusUpdate is not None:
+             flushEverything()
+         if task is None:
+             return None
+
+         if hasattr(self.featureExtractor, 'lexicon'):
+             if self.featureExtractor.tokenize(task.examples) is None:
+                 return None
+
+         ll = self.generativeModel.logLikelihood(request, program)
+         frontier = Frontier([FrontierEntry(program=program,
+                                            logLikelihood=0., logPrior=ll)],
+                             task=task)
+         return frontier
+
+     def sampleManyHelmholtz(self, requests, N, CPUs):
+         eprint("Sampling %d programs from the prior on %d CPUs..." % (N, CPUs))
+         flushEverything()
+         frequency = max(1, N // 50)  # integer division, so small N still gets status dots
+         startingSeed = random.random()
+
+         # Sequentially for ensemble training.
+         samples = [self.sampleHelmholtz(requests,
+                                         statusUpdate='.' if n % frequency == 0 else None,
+                                         seed=startingSeed + n) for n in range(N)]
+
+         # (cathywong) Disabled for ensemble training.
+         # samples = parallelMap(
+         #     1,
+         #     lambda n: self.sampleHelmholtz(requests,
+         #                                    statusUpdate='.' if n % frequency == 0 else None,
+         #                                    seed=startingSeed + n),
+         #     range(N))
+         eprint()
+         flushEverything()
+         samples = [z for z in samples if z is not None]
+         eprint()
+         eprint("Got %d/%d valid samples." % (len(samples), N))
+         flushEverything()
+
+         return samples
+
+     def enumerateFrontiers(self,
+                            tasks,
+                            enumerationTimeout=None,
+                            testing=False,
+                            solver=None,
+                            CPUs=1,
+                            frontierSize=None,
+                            maximumFrontier=None,
+                            evaluationTimeout=None):
+         with timing("Evaluated recognition model"):
+             grammars = {task: self.grammarOfTask(task)
+                         for task in tasks}
+             # untorch separately to make sure you filter out None grammars
+             grammars = {task: grammar.untorch() for task, grammar in grammars.items() if grammar is not None}
+
+         return multicoreEnumeration(grammars, tasks,
+                                     testing=testing,
+                                     solver=solver,
+                                     enumerationTimeout=enumerationTimeout,
+                                     CPUs=CPUs, maximumFrontier=maximumFrontier,
+                                     evaluationTimeout=evaluationTimeout)
+
+
+ class RecurrentFeatureExtractor(nn.Module):
+     def __init__(self, _=None,
+                  tasks=None,
+                  cuda=False,
+                  # what are the symbols that can occur in the inputs and
+                  # outputs
+                  lexicon=None,
+                  # how many hidden units
+                  H=32,
+                  # Should the recurrent units be bidirectional?
+                  bidirectional=False,
+                  # What should be the timeout for trying to construct Helmholtz tasks?
+                  helmholtzTimeout=0.25,
+                  # What should be the timeout for running a Helmholtz program?
+                  helmholtzEvaluationTimeout=0.01):
+         super(RecurrentFeatureExtractor, self).__init__()
+
+         assert tasks is not None, "You must provide a list of all of the tasks, both those that have been hit and those that have not been hit. Input examples are sampled from these tasks."
+
+         # maps from a requesting type to all of the inputs that we ever saw with that request
+         self.requestToInputs = {
+             tp: [list(map(fst, t.examples)) for t in tasks if t.request == tp]
+             for tp in {t.request for t in tasks}
+         }
+
+         inputTypes = {t
+                       for task in tasks
+                       for t in task.request.functionArguments()}
+         # maps from a type to all of the inputs that we ever saw having that type
+         self.argumentsWithType = {
+             tp: [x
+                  for t in tasks
+                  for xs, _ in t.examples
+                  for tpp, x in zip(t.request.functionArguments(), xs)
+                  if tpp == tp]
+             for tp in inputTypes
+         }
+         self.requestToNumberOfExamples = {
+             tp: [len(t.examples)
+                  for t in tasks if t.request == tp]
+             for tp in {t.request for t in tasks}
+         }
+         self.helmholtzTimeout = helmholtzTimeout
+         self.helmholtzEvaluationTimeout = helmholtzEvaluationTimeout
+         self.parallelTaskOfProgram = True
+
+         assert lexicon
+         self.specialSymbols = [
+             "STARTING",  # start of entire sequence
+             "ENDING",  # ending of entire sequence
+             "STARTOFOUTPUT",  # begins the start of the output
+             "ENDOFINPUT"  # delimits the ending of an input - we might have multiple inputs
+         ]
+         lexicon += self.specialSymbols
+         encoder = nn.Embedding(len(lexicon), H)
+         self.encoder = encoder
+
+         self.H = H
+         self.bidirectional = bidirectional
+
+         layers = 1
+
+         model = nn.GRU(H, H, layers, bidirectional=bidirectional)
+         self.model = model
+
+         self.use_cuda = cuda
+         self.lexicon = lexicon
+         self.symbolToIndex = {
+             symbol: index for index,
+             symbol in enumerate(lexicon)}
+         self.startingIndex = self.symbolToIndex["STARTING"]
+         self.endingIndex = self.symbolToIndex["ENDING"]
+         self.startOfOutputIndex = self.symbolToIndex["STARTOFOUTPUT"]
+         self.endOfInputIndex = self.symbolToIndex["ENDOFINPUT"]
+
+         # Maximum number of inputs/outputs we will run the recognition
+         # model on per task.
+         # This is an optimization hack.
+         self.MAXINPUTS = 100
+
+         if cuda: self.cuda()
+
+     @property
+     def outputDimensionality(self): return self.H
+
+     # modify examples before forward (to turn them into iterables of lexicon)
+     # you should override this if needed
+     def tokenize(self, x): return x
+
+     def symbolEmbeddings(self):
+         return {s: self.encoder(variable([self.symbolToIndex[s]])).squeeze(
+             0).data.cpu().numpy() for s in self.lexicon if not (s in self.specialSymbols)}
+
+     def packExamples(self, examples):
+         """IMPORTANT! xs must be sorted in decreasing order of size because pytorch is stupid"""
+         es = []
+         sizes = []
+         for xs, y in examples:
+             e = [self.startingIndex]
+             for x in xs:
+                 for s in x:
+                     e.append(self.symbolToIndex[s])
+                 e.append(self.endOfInputIndex)
+             e.append(self.startOfOutputIndex)
+             for s in y:
+                 e.append(self.symbolToIndex[s])
+             e.append(self.endingIndex)
+             if es != []:
+                 assert len(e) <= len(es[-1]), \
+                     "Examples must be sorted in decreasing order of their tokenized size. This should be transparently handled in recognition.py, so if this assertion fails it isn't your fault as a user of EC but instead is a bug inside of EC."
+             es.append(e)
+             sizes.append(len(e))
+
+         m = max(sizes)
+         # padding
+         for j, e in enumerate(es):
+             es[j] += [self.endingIndex] * (m - len(e))
+
+         x = variable(es, cuda=self.use_cuda)
+         x = self.encoder(x)
+         # x: (batch size, maximum length, E)
+         x = x.permute(1, 0, 2)
+         # x: TxBxE
+         x = pack_padded_sequence(x, sizes)
+         return x, sizes
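The decreasing-size requirement in `packExamples` comes from `torch.nn.utils.rnn.pack_padded_sequence`, which (without the `enforce_sorted=False` flag added in later PyTorch versions) requires the batch to be sorted by descending length. A small self-contained sketch of the same time-major packing pattern, with made-up shapes:

    import torch
    from torch.nn.utils.rnn import pack_padded_sequence

    # Three already-padded sequences of lengths 4, 3, 2 (descending), embedding size 5.
    T, B, E = 4, 3, 5
    x = torch.randn(T, B, E)   # time-major (TxBxE), like packExamples produces
    sizes = [4, 3, 2]          # must be non-increasing here

    packed = pack_padded_sequence(x, sizes)
    out, h = torch.nn.GRU(E, 8)(packed)  # the GRU skips the padded positions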
+
+     def examplesEncoding(self, examples):
+         examples = sorted(examples, key=lambda xs_y: sum(
+             len(z) + 1 for z in xs_y[0]) + len(xs_y[1]), reverse=True)
+         x, sizes = self.packExamples(examples)
+         outputs, hidden = self.model(x)
+         # outputs, sizes = pad_packed_sequence(outputs)
+         # I don't know whether to return the final output or the final hidden
+         # activations...
+         return hidden[0, :, :] + hidden[1, :, :]
+
+     def forward(self, examples):
+         tokenized = self.tokenize(examples)
+         if not tokenized:
+             return None
+
+         if hasattr(self, 'MAXINPUTS') and len(tokenized) > self.MAXINPUTS:
+             tokenized = list(tokenized)
+             random.shuffle(tokenized)
+             tokenized = tokenized[:self.MAXINPUTS]
+         e = self.examplesEncoding(tokenized)
+         # max pool
+         # e, _ = e.max(dim=0)
+
+         # Take the average activations across all of the examples.
+         # I think this might be better because we might be testing on data
+         # which has far more or far fewer examples than training.
+         e = e.mean(dim=0)
+         return e
+
+     def featuresOfTask(self, t):
+         if hasattr(self, 'useFeatures'):
+             f = self(t.features)
+         else:
+             # Featurize the examples directly.
+             f = self(t.examples)
+         return f
+
+     def taskOfProgram(self, p, tp):
+         # Half of the time we randomly mix together inputs;
+         # this gives better generalization on held-out tasks.
+         # The other half of the time we train on sets of inputs in the training data;
+         # this gives better generalization on unsolved training tasks.
+         if random.random() < 0.5:
+             def randomInput(t): return random.choice(self.argumentsWithType[t])
+             # Loop over the inputs in a random order and pick the first ones
+             # that don't raise an exception
+
+             startTime = time.time()
+             examples = []
+             while True:
+                 # TIMEOUT! this must not be a very good program
+                 if time.time() - startTime > self.helmholtzTimeout: return None
+
+                 # Grab some random inputs
+                 xs = [randomInput(t) for t in tp.functionArguments()]
+                 try:
+                     y = runWithTimeout(lambda: p.runWithArguments(xs), self.helmholtzEvaluationTimeout)
+                     examples.append((tuple(xs), y))
+                     if len(examples) >= random.choice(self.requestToNumberOfExamples[tp]):
+                         return Task("Helmholtz", tp, examples)
+                 except: continue
+
+         else:
+             candidateInputs = list(self.requestToInputs[tp])
+             random.shuffle(candidateInputs)
+             for xss in candidateInputs:
+                 ys = []
+                 for xs in xss:
+                     try: y = runWithTimeout(lambda: p.runWithArguments(xs), self.helmholtzEvaluationTimeout)
+                     except: break
+                     ys.append(y)
+                 if len(ys) == len(xss):
+                     return Task("Helmholtz", tp, list(zip(xss, ys)))
+             return None
+
+
+ class LowRank(nn.Module):
+     """
+     Module that outputs a rank-R matrix of size m by n from input of size i.
+     """
+     def __init__(self, i, m, n, r):
+         """
+         i: input dimension
+         m: output rows
+         n: output columns
+         r: maximum rank. if this is None then the output will be full-rank
+         """
+         super(LowRank, self).__init__()
+
+         self.m = m
+         self.n = n
+
+         maximumPossibleRank = min(m, n)
+         if r is None: r = maximumPossibleRank
+
+         if r < maximumPossibleRank:
+             self.factored = True
+             self.A = nn.Linear(i, m * r)
+             self.B = nn.Linear(i, n * r)
+             self.r = r
+         else:
+             self.factored = False
+             self.M = nn.Linear(i, m * n)
+
+     def forward(self, x):
+         sz = x.size()
+         if len(sz) == 1:
+             B = 1
+             x = x.unsqueeze(0)
+             needToSqueeze = True
+         elif len(sz) == 2:
+             B = sz[0]
+             needToSqueeze = False
+         else:
+             assert False, "LowRank expects either a 1-dimensional tensor or a 2-dimensional tensor"
+
+         if self.factored:
+             a = self.A(x).view(B, self.m, self.r)
+             b = self.B(x).view(B, self.r, self.n)
+             y = a @ b
+         else:
+             y = self.M(x).view(B, self.m, self.n)
+         if needToSqueeze:
+             y = y.squeeze(0)
+         return y
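A quick sanity check of `LowRank` as defined above: the factored path predicts `m*r + n*r` numbers per example instead of `m*n`, and the product `a @ b` has rank at most `r`. A sketch (the rank assertion assumes a recent PyTorch with `torch.linalg.matrix_rank`):

    import torch

    lr = LowRank(i=16, m=10, n=20, r=2)   # factored: 10*2 + 20*2 outputs per example
    y = lr(torch.randn(16))               # unbatched input -> a 10x20 matrix
    assert y.shape == (10, 20)
    assert torch.linalg.matrix_rank(y.detach()) <= 2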
+
+
+ class DummyFeatureExtractor(nn.Module):
+     def __init__(self, tasks, testingTasks=[], cuda=False):
+         super(DummyFeatureExtractor, self).__init__()
+         self.outputDimensionality = 1
+         self.recomputeTasks = False
+     def featuresOfTask(self, t):
+         return variable([0.]).float()
+     def featuresOfTasks(self, ts):
+         return variable([[0.]] * len(ts)).float()
+     def taskOfProgram(self, p, t):
+         return Task("dummy task", t, [])
+
+
+ class RandomFeatureExtractor(nn.Module):
+     def __init__(self, tasks):
+         super(RandomFeatureExtractor, self).__init__()
+         self.outputDimensionality = 1
+         self.recomputeTasks = False
+     def featuresOfTask(self, t):
+         return variable([random.random()]).float()
+     def featuresOfTasks(self, ts):
+         return variable([[random.random()] for _ in range(len(ts))]).float()
+     def taskOfProgram(self, p, t):
+         return Task("dummy task", t, [])
+
+
+ class Flatten(nn.Module):
+     def __init__(self):
+         super(Flatten, self).__init__()
+
+     def forward(self, x):
+         return x.view(x.size(0), -1)
+
+
+ class ImageFeatureExtractor(nn.Module):
+     def __init__(self, inputImageDimension, resizedDimension=None,
+                  channels=1):
+         super(ImageFeatureExtractor, self).__init__()
+
+         self.resizedDimension = resizedDimension or inputImageDimension
+         self.inputImageDimension = inputImageDimension
+         self.channels = channels
+
+         def conv_block(in_channels, out_channels):
+             return nn.Sequential(
+                 nn.Conv2d(in_channels, out_channels, 3, padding=1),
+                 # nn.BatchNorm2d(out_channels),
+                 nn.ReLU(),
+                 nn.MaxPool2d(2)
+             )
+
+         # channels for hidden
+         hid_dim = 64
+         z_dim = 64
+
+         self.encoder = nn.Sequential(
+             conv_block(channels, hid_dim),
+             conv_block(hid_dim, hid_dim),
+             conv_block(hid_dim, hid_dim),
+             conv_block(hid_dim, z_dim),
+             Flatten()
+         )
+
+         # Each layer of the encoder halves the spatial dimension, except for the last layer which flattens
+         outputImageDimensionality = self.resizedDimension / (2**(len(self.encoder) - 1))
+         self.outputDimensionality = int(z_dim * outputImageDimensionality * outputImageDimensionality)
+
+     def forward(self, v):
+         """1 channel: v: BxWxW or v: WxW
+         > 1 channel: v: BxCxWxW or v: CxWxW"""
+
+         insertBatch = False
+         variabled = variable(v).float()
+         if self.channels == 1:  # insert channel dimension
+             if len(variabled.shape) == 3:  # batching
+                 variabled = variabled[:, None, :, :]
+             elif len(variabled.shape) == 2:  # no batching
+                 variabled = variabled[None, :, :]
+                 insertBatch = True
+             else: assert False
+         else:  # expect to have a channel dimension
+             if len(variabled.shape) == 4:
+                 pass
+             elif len(variabled.shape) == 3:
+                 insertBatch = True
+             else: assert False
+
+         if insertBatch: variabled = torch.unsqueeze(variabled, 0)
+
+         y = self.encoder(variabled)
+         if insertBatch: y = y[0, :]
+         return y
+
+
+ class JSONFeatureExtractor(object):
+     def __init__(self, tasks, cuda=False):
+         # self.averages, self.deviations = Task.featureMeanAndStandardDeviation(tasks)
+         # self.outputDimensionality = len(self.averages)
+         self.cuda = cuda
+         self.tasks = tasks
+
+     def stringify(self, x):
+         # No whitespace # maybe kill the separators
+         return json.dumps(x, separators=(',', ':'))
+
+     def featuresOfTask(self, t):
+         # >>> t.request to get the type
+         # >>> t.examples to get input/output examples
+         # this might actually be okay, because the input should just be nothing
+         # return [(self.stringify(inputs), self.stringify(output))
+         #         for (inputs, output) in t.examples]
+         return [(list(output),) for (inputs, output) in t.examples]
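The output-dimensionality arithmetic in `ImageFeatureExtractor` is worth spelling out: four conv blocks each halve the spatial size (the fifth `Sequential` entry is `Flatten`), so a resized WxW input yields a feature vector of size z_dim * (W/16)^2. For example, with the hypothetical W below:

    W = 64                         # resizedDimension
    blocks = 4                     # len(encoder) - 1, excluding Flatten
    spatial = W / (2 ** blocks)    # 64 / 16 = 4.0
    z_dim = 64
    assert int(z_dim * spatial * spatial) == 1024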
dreamcoder/task.py ADDED
@@ -0,0 +1,244 @@
+ from dreamcoder.program import *
+ from dreamcoder.differentiation import *
+
+ import signal
+
+
+ class EvaluationTimeout(Exception):
+     pass
+
+
+ EVALUATIONTABLE = {}
+
+
+ class Task(object):
+     def __init__(self, name, request, examples, features=None, cache=False):
+         '''request: the type of this task
+         examples: list of tuples of (input, output). input should be a tuple, with one entry for each argument
+         cache: should program evaluations be cached?
+         features: list of floats.'''
+         self.cache = cache
+         self.features = features
+         self.request = request
+         self.name = name
+         self.examples = examples
+         if len(self.examples) > 0:
+             assert all(len(xs) == len(examples[0][0])
+                        for xs, _ in examples), \
+                 "(for task %s) FATAL: Number of arguments varies." % name
+
+     def __str__(self):
+         if self.supervision is None:
+             return self.name
+         else:
+             return self.name + " (%s)" % self.supervision
+
+     def __repr__(self):
+         return "Task(name={self.name}, request={self.request}, examples={self.examples})"\
+             .format(self=self)
+
+     def __eq__(self, o): return self.name == o.name
+
+     def __ne__(self, o): return not (self == o)
+
+     def __hash__(self): return hash(self.name)
+
+     def describe(self):
+         description = ["%s : %s" % (self.name, self.request)]
+         for xs, y in self.examples:
+             if len(xs) == 1:
+                 description.append("f(%s) = %s" % (xs[0], y))
+             else:
+                 description.append("f%s = %s" % (xs, y))
+         return "\n".join(description)
+
+     def predict(self, f, x):
+         for a in x:
+             f = f(a)
+         return f
+
+     @property
+     def supervision(self):
+         if not hasattr(self, 'supervisedSolution'): return None
+         return self.supervisedSolution
+
+     def check(self, e, timeout=None):
+         if timeout is not None:
+             def timeoutCallBack(_1, _2): raise EvaluationTimeout()
+         try:
+             if timeout is not None:
+                 signal.signal(signal.SIGVTALRM, timeoutCallBack)
+                 signal.setitimer(signal.ITIMER_VIRTUAL, timeout)
+
+             try:
+                 f = e.evaluate([])
+             except IndexError:
+                 # free variable
+                 return False
+             except Exception as e:
+                 eprint("Exception during evaluation:", e)
+                 return False
+
+             for x, y in self.examples:
+                 if self.cache and (x, e) in EVALUATIONTABLE:
+                     p = EVALUATIONTABLE[(x, e)]
+                 else:
+                     try:
+                         p = self.predict(f, x)
+                     except BaseException:
+                         p = None
+                     if self.cache:
+                         EVALUATIONTABLE[(x, e)] = p
+                 if p != y:
+                     if timeout is not None:
+                         signal.signal(signal.SIGVTALRM, lambda *_: None)
+                         signal.setitimer(signal.ITIMER_VIRTUAL, 0)
+                     return False
+
+             return True
+         # except e:
+         #     eprint(e)
+         #     assert(False)
+         except EvaluationTimeout:
+             eprint("Timed out while evaluating", e)
+             return False
+         finally:
+             if timeout is not None:
+                 signal.signal(signal.SIGVTALRM, lambda *_: None)
+                 signal.setitimer(signal.ITIMER_VIRTUAL, 0)
+
+     def logLikelihood(self, e, timeout=None):
+         if self.check(e, timeout):
+             return 0.0
+         else:
+             return NEGATIVEINFINITY
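`check` above implements its timeout with a virtual interval timer: `SIGVTALRM` fires after `timeout` seconds of CPU time and the handler raises `EvaluationTimeout`. The same stdlib pattern in isolation (Unix-only; `signal.setitimer` is unavailable on Windows):

    import signal

    class Timeout(Exception): pass

    def run_with_cpu_timeout(f, seconds):
        def handler(signum, frame): raise Timeout()
        signal.signal(signal.SIGVTALRM, handler)
        signal.setitimer(signal.ITIMER_VIRTUAL, seconds)
        try:
            return f()
        finally:
            # Always disarm the timer, exactly as check() does in its finally block.
            signal.signal(signal.SIGVTALRM, lambda *_: None)
            signal.setitimer(signal.ITIMER_VIRTUAL, 0)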
+
+     @staticmethod
+     def featureMeanAndStandardDeviation(tasks):
+         dimension = len(tasks[0].features)
+         averages = [sum(t.features[j] for t in tasks) / float(len(tasks))
+                     for j in range(dimension)]
+         variances = [sum((t.features[j] -
+                           averages[j])**2 for t in tasks) /
+                      float(len(tasks)) for j in range(dimension)]
+         standardDeviations = [v**0.5 for v in variances]
+         for j, s in enumerate(standardDeviations):
+             if s == 0.:
+                 eprint(
+                     "WARNING: Feature %d is always %f" %
+                     (j + 1, averages[j]))
+         return averages, standardDeviations
+
+     def as_json_dict(self):
+         return {
+             "name": self.name,
+             "request": str(self.request),
+             "examples": [{"inputs": x, "output": y} for x, y in self.examples]
+         }
+
+
+ class DifferentiableTask(Task):
+
+     def __init__(self, name, request, examples, _=None,
+                  features=None, BIC=1., loss=None, likelihoodThreshold=None,
+                  steps=50, restarts=300, lr=0.5, decay=0.5, grow=1.2, actualParameters=None,
+                  temperature=1., maxParameters=None, clipLoss=None, clipOutput=None):
+         assert loss is not None
+         self.temperature = temperature
+         self.actualParameters = actualParameters
+         self.maxParameters = maxParameters
+         self.loss = loss
+         self.BIC = BIC
+         self.likelihoodThreshold = likelihoodThreshold
+
+         arguments = {"parameterPenalty": BIC * math.log(len(examples)),
+                      "temperature": temperature,
+                      "steps": steps, "restarts": restarts, "lr": lr, "decay": decay, "grow": grow,
+                      "maxParameters": maxParameters,
+                      "lossThreshold": -likelihoodThreshold if likelihoodThreshold is not None else None}
+         if clipLoss is not None: arguments['clipLoss'] = float(clipLoss)
+         if clipOutput is not None: arguments['clipOutput'] = float(clipOutput)
+         if actualParameters is not None: arguments['actualParameters'] = int(actualParameters)
+
+         self.specialTask = ("differentiable",
+                             arguments)
+
+         super(
+             DifferentiableTask,
+             self).__init__(
+             name,
+             request,
+             examples,
+             features,
+             cache=False)
+
+     def logLikelihood(self, e, timeout=None):
+         assert timeout is None, "timeout not implemented for differentiable tasks, but not for any good reason."
+         e, parameters = PlaceholderVisitor.execute(e)
+         if self.maxParameters is not None and len(
+                 parameters) > self.maxParameters:
+             return NEGATIVEINFINITY
+         if self.actualParameters is not None and len(
+                 parameters) > self.actualParameters:
+             return NEGATIVEINFINITY
+         f = e.evaluate([])
+
+         loss = sum(self.loss(self.predict(f, xs), y)
+                    for xs, y in self.examples) / float(len(self.examples))
+         if isinstance(loss, DN):
+             try:
+                 loss = loss.restartingOptimize(
+                     parameters,
+                     lr=self.specialTask[1]["lr"],
+                     steps=self.specialTask[1]["steps"],
+                     decay=self.specialTask[1]["decay"],
+                     grow=self.specialTask[1]["grow"],
+                     attempts=self.specialTask[1]["restarts"],
+                     update=None)
+             except InvalidLoss:
+                 loss = POSITIVEINFINITY
+
+         # BIC penalty
+         penalty = self.BIC * len(parameters) * math.log(len(self.examples))
+
+         if self.likelihoodThreshold is not None:
+             if loss > -self.likelihoodThreshold:
+                 return NEGATIVEINFINITY
+             else:
+                 return -penalty
+         else:
+             return -loss / self.temperature - penalty
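The score returned above combines a temperature-scaled fit term with a BIC-style complexity penalty, `penalty = BIC * |parameters| * log(#examples)`. A worked instance of the arithmetic, with hypothetical numbers:

    import math

    loss, temperature = 2.0, 1.0   # average per-example loss after optimization
    BIC, nParameters, nExamples = 1.0, 3, 20

    penalty = BIC * nParameters * math.log(nExamples)   # ~8.99
    score = -loss / temperature - penalty               # ~-10.99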
+
+
+ def squaredErrorLoss(prediction, target):
+     d = prediction - target
+     return d * d
+
+
+ def l1loss(prediction, target):
+     return abs(prediction - target)
+
+
+ class PlaceholderVisitor(object):
+     def __init__(self): self.parameters = []
+
+     def primitive(self, e):
+         if e.name == 'REAL':
+             placeholder = Placeholder.named("REAL_", random.random())
+             self.parameters.append(placeholder)
+             return Primitive(e.name, e.tp, placeholder)
+         return e
+
+     def invented(self, e): return e.body.visit(self)
+
+     def abstraction(self, e): return Abstraction(e.body.visit(self))
+
+     def application(self, e):
+         return Application(e.f.visit(self), e.x.visit(self))
+
+     def index(self, e): return e
+
+     @staticmethod
+     def execute(e):
+         v = PlaceholderVisitor()
+         e = e.visit(v)
+         return e, v.parameters
dreamcoder/taskBatcher.py ADDED
@@ -0,0 +1,200 @@
+ from dreamcoder.utilities import eprint
+ import random
+
+
+ class DefaultTaskBatcher:
+     """Iterates through task batches of the specified size. Defaults to all tasks if taskBatchSize is None."""
+
+     def __init__(self):
+         pass
+
+     def getTaskBatch(self, ec_result, tasks, taskBatchSize, currIteration):
+         if taskBatchSize is None:
+             taskBatchSize = len(tasks)
+         elif taskBatchSize > len(tasks):
+             eprint("Task batch size is greater than total number of tasks, aborting.")
+             assert False
+
+         start = (taskBatchSize * currIteration) % len(tasks)
+         end = start + taskBatchSize
+         taskBatch = (tasks + tasks)[start:end]  # Handle wraparound.
+         return taskBatch
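The `(tasks + tasks)[start:end]` trick above lets a batch wrap around the end of the task list without any index arithmetic. Concretely:

    tasks = ["a", "b", "c", "d", "e"]
    taskBatchSize, currIteration = 3, 3

    start = (taskBatchSize * currIteration) % len(tasks)  # 9 % 5 = 4
    batch = (tasks + tasks)[start:start + taskBatchSize]
    assert batch == ["e", "a", "b"]  # wraps past the end of the list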
+
+
+ class RandomTaskBatcher:
+     """Returns a randomly sampled task batch of the specified size. Defaults to all tasks if taskBatchSize is None."""
+
+     def __init__(self):
+         pass
+
+     def getTaskBatch(self, ec_result, tasks, taskBatchSize, currIteration):
+         if taskBatchSize is None:
+             taskBatchSize = len(tasks)
+         elif taskBatchSize > len(tasks):
+             eprint("Task batch size is greater than total number of tasks, aborting.")
+             assert False
+
+         return random.sample(tasks, taskBatchSize)
+
+
+ class RandomShuffleTaskBatcher:
+     """Randomly shuffles the tasks first, and then iterates through task batches of the specified size like DefaultTaskBatcher.
+     Reshuffles across epochs - intended as a benchmark comparison to test the task ordering."""
+     def __init__(self, baseSeed=0): self.baseSeed = baseSeed
+
+     def getTaskBatch(self, ec_result, tasks, taskBatchSize, currIteration):
+         if taskBatchSize is None:
+             taskBatchSize = len(tasks)
+         elif taskBatchSize > len(tasks):
+             eprint("Task batch size is greater than total number of tasks, aborting.")
+             assert False
+
+         # Reshuffles tasks in a fixed way across epochs for reproducibility.
+         currEpoch = int(int(currIteration * taskBatchSize) / int(len(tasks)))
+
+         shuffledTasks = tasks.copy()  # Since shuffle works in place.
+         random.Random(self.baseSeed + currEpoch).shuffle(shuffledTasks)
+
+         shuffledTasksWrap = tasks.copy()  # Since shuffle works in place.
+         random.Random(self.baseSeed + currEpoch + 1).shuffle(shuffledTasksWrap)
+
+         start = (taskBatchSize * currIteration) % len(shuffledTasks)
+         end = start + taskBatchSize
+         taskBatch = (shuffledTasks + shuffledTasksWrap)[start:end]  # Wraparound nicely.
+
+         return list(set(taskBatch))
+
+
+ class UnsolvedTaskBatcher:
+     """At a given epoch, returns only the tasks that have not been solved at least twice."""
+
+     def __init__(self):
+         self.timesSolved = {}  # map from task to the number of times that we have solved it
+         self.start = 0
+
+     def getTaskBatch(self, ec_result, tasks, taskBatchSize, currIteration):
+         assert taskBatchSize is None, "This batching strategy does not support batch sizes"
+
+         for t, f in ec_result.allFrontiers.items():
+             if f.empty:
+                 self.timesSolved[t] = max(0, self.timesSolved.get(t, 0))
+             else:
+                 self.timesSolved[t] = 1 + self.timesSolved.get(t, 0)
+         return [t for t in tasks if self.timesSolved.get(t, 0) < 2]
+
+
+ def entropyRandomBatch(ec_result, tasks, taskBatchSize, randomRatio):
+     numRandom = int(randomRatio * taskBatchSize)
+     numEntropy = taskBatchSize - numRandom
+
+     eprint("Selecting top %d tasks from the %d overall tasks given lowest entropy." % (taskBatchSize, len(tasks)))
+     eprint("Will be selecting %d by lowest entropy and %d randomly." % (numEntropy, numRandom))
+     taskGrammarEntropies = ec_result.recognitionModel.taskGrammarEntropies(tasks)
+     sortedEntropies = sorted(taskGrammarEntropies.items(), key=lambda x: x[1])
+
+     entropyBatch = [task for (task, entropy) in sortedEntropies[:numEntropy]]
+     randomBatch = random.sample([task for (task, entropy) in sortedEntropies[numEntropy:]], numRandom)
+     batch = entropyBatch + randomBatch
+
+     return batch
+
+
+ def kNearestNeighbors(ec_result, tasks, k, task):
+     """Finds the k nearest neighbors in the recognition model logProduction space to a given task."""
+     import numpy as np
+     cosDistance = ec_result.recognitionModel.grammarLogProductionDistanceToTask(task, tasks)
+     argSort = np.argsort(-cosDistance)  # Want the greatest similarity.
+     topK = argSort[:k]
+     topKTasks = list(np.array(tasks)[topK])
+     return topKTasks
+
+
+ class RandomkNNTaskBatcher:
+     """Chooses a random task and finds the (taskBatchSize - 1) nearest neighbors using the recognition model logits."""
+     def __init__(self):
+         pass
+
+     def getTaskBatch(self, ec_result, tasks, taskBatchSize, currIteration):
+         if taskBatchSize is None:
+             taskBatchSize = len(tasks)
+         elif taskBatchSize > len(tasks):
+             eprint("Task batch size is greater than total number of tasks, aborting.")
+             assert False
+
+         if ec_result.recognitionModel is None:
+             eprint("No recognition model, falling back on random %d" % taskBatchSize)
+             return random.sample(tasks, taskBatchSize)
+         else:
+             randomTask = random.choice(tasks)
+             kNN = kNearestNeighbors(ec_result, tasks, taskBatchSize - 1, randomTask)
+             return [randomTask] + kNN
+
+
+ class RandomLowEntropykNNTaskBatcher:
+     """Chooses a random task from the lowest-entropy unsolved tasks, and finds the (taskBatchSize - 1) nearest neighbors using the recognition model logits."""
+     def __init__(self):
+         pass
+
+     def getTaskBatch(self, ec_result, tasks, taskBatchSize, currIteration):
+         unsolvedTasks = [t for t in tasks if ec_result.allFrontiers[t].empty]
+
+         if taskBatchSize is None:
+             return unsolvedTasks
+         elif taskBatchSize > len(tasks):
+             eprint("Task batch size is greater than total number of tasks, aborting.")
+             assert False
+
+         if ec_result.recognitionModel is None:
+             eprint("No recognition model, falling back on random %d tasks from the remaining %d" % (taskBatchSize, len(unsolvedTasks)))
+             return random.sample(unsolvedTasks, taskBatchSize)
+         else:
+             lowEntropyUnsolved = entropyRandomBatch(ec_result, unsolvedTasks, taskBatchSize, randomRatio=0)
+             randomTask = random.choice(lowEntropyUnsolved)
+             kNN = kNearestNeighbors(ec_result, tasks, taskBatchSize - 1, randomTask)
+             return [randomTask] + kNN
+
+
+ class UnsolvedEntropyTaskBatcher:
+     """Returns tasks that have never been solved at any previous iteration.
+     Given a task batch size, returns the unsolved tasks with the lowest entropy."""
+     def __init__(self):
+         pass
+
+     def getTaskBatch(self, ec_result, tasks, taskBatchSize, currIteration):
+         unsolvedTasks = [t for t in tasks if ec_result.allFrontiers[t].empty]
+
+         if taskBatchSize is None:
+             return unsolvedTasks
+         elif taskBatchSize > len(tasks):
+             eprint("Task batch size is greater than total number of tasks, aborting.")
+             assert False
+
+         if ec_result.recognitionModel is None:
+             eprint("No recognition model, falling back on random %d tasks from the remaining %d" % (taskBatchSize, len(unsolvedTasks)))
+             return random.sample(unsolvedTasks, taskBatchSize)
+         else:
+             return entropyRandomBatch(ec_result, unsolvedTasks, taskBatchSize, randomRatio=0)
+
+
+ class UnsolvedRandomEntropyTaskBatcher:
+     """Returns tasks that have never been solved at any previous iteration.
+     Given a task batch size, returns a mix of unsolved tasks with percentRandom
+     selected randomly and the remaining selected by lowest entropy."""
+     def __init__(self):
+         pass
+
+     def getTaskBatch(self, ec_result, tasks, taskBatchSize, currIteration):
+         unsolvedTasks = [t for t in tasks if ec_result.allFrontiers[t].empty]
+
+         if taskBatchSize is None:
+             return unsolvedTasks
+         elif taskBatchSize > len(tasks):
+             eprint("Task batch size is greater than total number of tasks, aborting.")
+             assert False
+
+         if ec_result.recognitionModel is None:
+             eprint("No recognition model, falling back on random %d tasks from the remaining %d" % (taskBatchSize, len(unsolvedTasks)))
+             return random.sample(unsolvedTasks, taskBatchSize)
+         else:
+             return entropyRandomBatch(ec_result, unsolvedTasks, taskBatchSize, randomRatio=.5)
dreamcoder/type.py ADDED
@@ -0,0 +1,378 @@
+ class UnificationFailure(Exception):
+     pass
+
+
+ class Occurs(UnificationFailure):
+     pass
+
+
+ class Type(object):
+     def __str__(self): return self.show(True)
+
+     def __repr__(self): return str(self)
+
+     @staticmethod
+     def fromjson(j):
+         if "index" in j: return TypeVariable(j["index"])
+         if "constructor" in j: return TypeConstructor(j["constructor"],
+                                                       [Type.fromjson(a) for a in j["arguments"]])
+         assert False
+
+
+ class TypeConstructor(Type):
+     def __init__(self, name, arguments):
+         self.name = name
+         self.arguments = arguments
+         self.isPolymorphic = any(a.isPolymorphic for a in arguments)
+
+     def makeDummyMonomorphic(self, mapping=None):
+         mapping = mapping if mapping is not None else {}
+         return TypeConstructor(self.name,
+                                [a.makeDummyMonomorphic(mapping) for a in self.arguments])
+
+     def __eq__(self, other):
+         return isinstance(other, TypeConstructor) and \
+             self.name == other.name and \
+             all(x == y for x, y in zip(self.arguments, other.arguments))
+
+     def __hash__(self): return hash((self.name,) + tuple(self.arguments))
+
+     def __ne__(self, other):
+         return not (self == other)
+
+     def show(self, isReturn):
+         if self.name == ARROW:
+             if isReturn:
+                 return "%s %s %s" % (self.arguments[0].show(
+                     False), ARROW, self.arguments[1].show(True))
+             else:
+                 return "(%s %s %s)" % (self.arguments[0].show(
+                     False), ARROW, self.arguments[1].show(True))
+         elif self.arguments == []:
+             return self.name
+         else:
+             return "%s(%s)" % (self.name, ", ".join(x.show(True)
+                                                     for x in self.arguments))
+
+     def json(self):
+         return {"constructor": self.name,
+                 "arguments": [a.json() for a in self.arguments]}
+
+     def isArrow(self): return self.name == ARROW
+
+     def functionArguments(self):
+         if self.name == ARROW:
+             xs = self.arguments[1].functionArguments()
+             return [self.arguments[0]] + xs
+         return []
+
+     def returns(self):
+         if self.name == ARROW:
+             return self.arguments[1].returns()
+         else:
+             return self
+
+     def apply(self, context):
+         if not self.isPolymorphic:
+             return self
+         return TypeConstructor(self.name,
+                                [x.apply(context) for x in self.arguments])
+
+     def applyMutable(self, context):
+         if not self.isPolymorphic:
+             return self
+         return TypeConstructor(self.name,
+                                [x.applyMutable(context) for x in self.arguments])
+
+     def occurs(self, v):
+         if not self.isPolymorphic:
+             return False
+         return any(x.occurs(v) for x in self.arguments)
+
+     def negateVariables(self):
+         return TypeConstructor(self.name,
+                                [a.negateVariables() for a in self.arguments])
+
+     def instantiate(self, context, bindings=None):
+         if not self.isPolymorphic:
+             return context, self
+         if bindings is None:
+             bindings = {}
+         newArguments = []
+         for x in self.arguments:
+             (context, x) = x.instantiate(context, bindings)
+             newArguments.append(x)
+         return (context, TypeConstructor(self.name, newArguments))
+
+     def instantiateMutable(self, context, bindings=None):
+         if not self.isPolymorphic:
+             return self
+         if bindings is None:
+             bindings = {}
+         return TypeConstructor(self.name, [x.instantiateMutable(context, bindings)
+                                            for x in self.arguments])
+
+     def canonical(self, bindings=None):
+         if not self.isPolymorphic:
+             return self
+         if bindings is None:
+             bindings = {}
+         return TypeConstructor(self.name,
+                                [x.canonical(bindings) for x in self.arguments])
+
+
+ class TypeVariable(Type):
+     def __init__(self, j):
+         assert isinstance(j, int)
+         self.v = j
+         self.isPolymorphic = True
+
+     def makeDummyMonomorphic(self, mapping=None):
+         mapping = mapping if mapping is not None else {}
+         if self.v not in mapping:
+             mapping[self.v] = TypeConstructor(f"dummy_type_{len(mapping)}", [])
+         return mapping[self.v]
+
+     def __eq__(self, other):
+         return isinstance(other, TypeVariable) and self.v == other.v
+
+     def __ne__(self, other): return not (self == other)
+
+     def __hash__(self): return self.v
+
+     def show(self, _): return "t%d" % self.v
+
+     def json(self):
+         return {"index": self.v}
+
+     def returns(self): return self
+
+     def isArrow(self): return False
+
+     def functionArguments(self): return []
+
+     def apply(self, context):
+         for v, t in context.substitution:
+             if v == self.v:
+                 return t.apply(context)
+         return self
+
+     def applyMutable(self, context):
+         s = context.substitution[self.v]
+         if s is None: return self
+         new = s.applyMutable(context)
+         context.substitution[self.v] = new
+         return new
+
+     def occurs(self, v): return v == self.v
+
+     def instantiate(self, context, bindings=None):
+         if bindings is None:
+             bindings = {}
+         if self.v in bindings:
+             return (context, bindings[self.v])
+         new = TypeVariable(context.nextVariable)
+         bindings[self.v] = new
+         context = Context(context.nextVariable + 1, context.substitution)
+         return (context, new)
+
+     def instantiateMutable(self, context, bindings=None):
+         if bindings is None: bindings = {}
+         if self.v in bindings: return bindings[self.v]
+         new = context.makeVariable()
+         bindings[self.v] = new
+         return new
+
+     def canonical(self, bindings=None):
+         if bindings is None:
+             bindings = {}
+         if self.v in bindings:
+             return bindings[self.v]
+         new = TypeVariable(len(bindings))
+         bindings[self.v] = new
+         return new
+
+     def negateVariables(self):
+         return TypeVariable(-1 - self.v)
+
+
+ class Context(object):
+     def __init__(self, nextVariable=0, substitution=[]):
+         self.nextVariable = nextVariable
+         self.substitution = substitution
+
+     def extend(self, j, t):
+         return Context(self.nextVariable, [(j, t)] + self.substitution)
+
+     def makeVariable(self):
+         return (Context(self.nextVariable + 1, self.substitution),
+                 TypeVariable(self.nextVariable))
+
+     def unify(self, t1, t2):
+         t1 = t1.apply(self)
+         t2 = t2.apply(self)
+         if t1 == t2:
+             return self
+         # t1 & t2 are not equal
+         if not t1.isPolymorphic and not t2.isPolymorphic:
+             raise UnificationFailure(t1, t2)
+
+         if isinstance(t1, TypeVariable):
+             if t2.occurs(t1.v):
+                 raise Occurs()
+             return self.extend(t1.v, t2)
+         if isinstance(t2, TypeVariable):
+             if t1.occurs(t2.v):
+                 raise Occurs()
+             return self.extend(t2.v, t1)
+         if t1.name != t2.name:
+             raise UnificationFailure(t1, t2)
+         k = self
+         for x, y in zip(t2.arguments, t1.arguments):
+             k = k.unify(x, y)
+         return k
+
+     def __str__(self):
+         return "Context(next = %d, {%s})" % (self.nextVariable, ", ".join(
+             "t%d ||> %s" % (k, v.apply(self)) for k, v in self.substitution))
+
+     def __repr__(self): return str(self)
+
+
+ class MutableContext(object):
+     def __init__(self):
+         self.substitution = []
+
+     def extend(self, i, t):
+         assert self.substitution[i] is None
+         self.substitution[i] = t
+
+     def makeVariable(self):
+         self.substitution.append(None)
+         return TypeVariable(len(self.substitution) - 1)
+
+     def unify(self, t1, t2):
+         t1 = t1.applyMutable(self)
+         t2 = t2.applyMutable(self)
+
+         if t1 == t2: return
+
+         # t1 & t2 are not equal
+         if not t1.isPolymorphic and not t2.isPolymorphic:
+             raise UnificationFailure(t1, t2)
+
+         if isinstance(t1, TypeVariable):
+             if t2.occurs(t1.v):
+                 raise Occurs()
+             self.extend(t1.v, t2)
+             return
+         if isinstance(t2, TypeVariable):
+             if t1.occurs(t2.v):
+                 raise Occurs()
+             self.extend(t2.v, t1)
+             return
+         if t1.name != t2.name:
+             raise UnificationFailure(t1, t2)
+
+         for x, y in zip(t2.arguments, t1.arguments):
+             self.unify(x, y)
+
+
+ Context.EMPTY = Context(0, [])
+
+
+ def canonicalTypes(ts):
+     bindings = {}
+     return [t.canonical(bindings) for t in ts]
+
+
+ def instantiateTypes(context, ts):
+     bindings = {}
+     newTypes = []
+     for t in ts:
+         context, t = t.instantiate(context, bindings)
+         newTypes.append(t)
+     return context, newTypes
+
+
+ def baseType(n): return TypeConstructor(n, [])
+
+
+ tint = baseType("int")
+ treal = baseType("real")
+ tbool = baseType("bool")
+ tboolean = tbool  # alias
+ tcharacter = baseType("char")
+
+
+ def tlist(t): return TypeConstructor("list", [t])
+
+
+ def tpair(a, b): return TypeConstructor("pair", [a, b])
+
+
+ def tmaybe(t): return TypeConstructor("maybe", [t])
+
+
+ tstr = tlist(tcharacter)
+ t0 = TypeVariable(0)
+ t1 = TypeVariable(1)
+ t2 = TypeVariable(2)
+
+ # regex types
+ tpregex = baseType("pregex")
+
+ ARROW = "->"
+
+
+ def arrow(*arguments):
+     if len(arguments) == 1:
+         return arguments[0]
+     return TypeConstructor(ARROW, [arguments[0], arrow(*arguments[1:])])
+
+
+ def inferArg(tp, tcaller):
+     ctx, tp = tp.instantiate(Context.EMPTY)
+     ctx, tcaller = tcaller.instantiate(ctx)
+     ctx, targ = ctx.makeVariable()
+     ctx = ctx.unify(tcaller, arrow(targ, tp))
+     return targ.apply(ctx)
+
+
+ def guess_type(xs):
+     """
+     Return a TypeConstructor corresponding to xs' python type.
+     Raises an exception if the type cannot be guessed.
+     """
+     if all(isinstance(x, bool) for x in xs):
+         return tbool
+     elif all(isinstance(x, int) for x in xs):
+         return tint
+     elif all(isinstance(x, str) for x in xs):
+         return tstr
+     elif all(isinstance(x, list) for x in xs):
+         return tlist(guess_type([y for ys in xs for y in ys]))
+     else:
+         raise ValueError("cannot guess type from {}".format(xs))
+
+
+ def guess_arrow_type(examples):
+     a = len(examples[0][0])
+     input_types = []
+     for n in range(a):
+         input_types.append(guess_type([xs[n] for xs, _ in examples]))
+     output_type = guess_type([y for _, y in examples])
+     return arrow(*(input_types + [output_type]))
+
+
+ def canUnify(t1, t2):
+     k = MutableContext()
+     t1 = t1.instantiateMutable(k)
+     t2 = t2.instantiateMutable(k)
+     try:
+         k.unify(t1, t2)
+         return True
+     except UnificationFailure: return False
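To close, a usage sketch of `guess_arrow_type` and `arrow` from above. Note that `guess_type` checks `bool` before `int` because Python's `bool` is a subclass of `int`, so the branch order matters.

    examples = [((1, [2, 3]), 6), ((0, [5]), 5)]   # two (inputs, output) pairs
    tp = guess_arrow_type(examples)
    print(tp)   # int -> list(int) -> int

    assert canUnify(tp, arrow(tint, tlist(tint), tint))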