grg's picture
Parameter selection added
f397ead
from abc import ABC
from graphviz import Digraph
import re
import random
from termcolor import cprint
from collections import defaultdict
class Node:
def __init__(self, id, label, type, parent=None):
"""
type: type must be "param" or "value"
for type "param" one of the children must be chosen
for type "value" all of the children must be set
"""
self.id = id
self.label = label
self.parent = parent
self.children = []
self.type = type
# calculate node's level
parent_ = self.parent
self.level = 1
while parent_ is not None:
self.level += 1
parent_ = parent_.parent
def __repr__(self):
return f"{self.id}({self.type})-'{self.label}'"
class ParameterTree(ABC):
def __init__(self):
self.last_node_id = 0
self.create_digraph()
self.nodes = {}
self.root = None
def create_digraph(self):
self.tree = Digraph("unix", format='svg')
self.tree.attr(size='30,100')
def get_node_for_id(self, id):
return self.nodes[id]
def add_node(self, label, parent=None, type="param"):
"""
All children of this node must be set
"""
if type not in ["param", "value"]:
raise ValueError('Node type must be "param" or "value"')
if parent is None and self.root is not None:
raise ValueError("Root already set: {}. parent cannot be None. ".format(self.root.id))
# add to graph
node_id = self.new_node_id()
self.nodes[node_id] = Node(id=node_id, label=label, parent=parent, type=type)
if parent is None:
self.root = self.nodes[node_id]
else:
self.nodes[parent.id].children.append(self.nodes[node_id])
return self.nodes[node_id]
def sample_env_params(self, ACL=None):
parameters = {}
nodes = [self.root]
# BFS
while nodes:
node = nodes.pop(0)
if node.type == "param":
if len(node.children) == 0:
raise ValueError("Node {} doesn't have any children.".format(node.label))
if ACL is None:
# choose randomly
chosen = random.choice(node.children)
else:
# let the ACL choose
chosen = ACL.choose(node, parameters)
assert chosen.type == "value"
nodes.append(chosen)
parameters[node] = chosen
elif node.type == "value":
nodes.extend(node.children)
else:
raise ValueError('Node type must be "param" or "value" and is {}'.format(node.type))
return parameters
def new_node_id(self):
new_id = self.last_node_id + 1
self.last_node_id = new_id
return str("node_"+str(new_id))
def print_tree(self, selected_parameters={}):
print("Parameter tree")
nodes = [self.root]
color = None
# BFS
while nodes:
node = nodes.pop(0)
if node.type == "param":
if node in selected_parameters.keys():
color = "blue"
else:
color = None
if node.parent is not None:
cprint("{}: {} ({}) -----> {}: {} ({})".format(
node.parent.type, node.parent.label, node.parent.id,
node.type, node.label, node.id
), color)
else:
cprint("{}: {} ({})".format(node.type, node.label, node.id), color)
nodes.extend(node.children)
def get_all_params(self):
all_params = defaultdict(list)
nodes = [self.root]
while nodes:
node = nodes.pop(0)
if node.type == "value":
all_params[node.parent].append(node)
nodes.extend(node.children)
return all_params
def draw_tree(self, filename, selected_parameters={}, ignore_labels=[], folded_nodes=[], label_parser={}, save=True):
self.create_digraph()
nodes = [self.root]
color_param = "grey60"
color_value = "lightgray"
fontcolor = "black"
fontsize = "18"
dots_fontsize = "30"
folded_param = "grey95"
folded_value = "grey95"
folded_fontcolor = "gray70"
def add_fold_symbol(label, folded=False):
return label
# return label + " ❯" if folded else label
# BFS - construct vizgraph
while nodes:
node = nodes.pop(0)
while node.label in ignore_labels:
node = nodes.pop(0)
if node.label in folded_nodes:
n_label = label_parser.get(node.label, node.label)
n_label = add_fold_symbol(n_label, folded=True)
if node.type == "param":
color = folded_param
self.tree.attr('node', shape='box', style="filled", color=color, fontcolor=folded_fontcolor, fontsize=fontsize)
self.tree.node(name=node.id, label=n_label, type="parameter")
elif node.type == "value":
color = folded_value
self.tree.attr('node', shape='ellipse', style='filled', color=color, fontcolor=folded_fontcolor, fontsize=fontsize)
self.tree.node(name=node.id, label=n_label, type="value")
else:
raise ValueError(f"Undefined node type {node.type}")
# add folded node sign
folded_node_id = node.id+"_fold"
# self.tree.attr('node', shape='ellipse', style='filled', color="white", fontcolor=folded_fontcolor, fontsize=fontsize)
# self.tree.attr('node', shape='none', style='filled', color="gray", fontcolor=folded_fontcolor, fontsize=dots_fontsize)
self.tree.attr('node', shape='none', color="white",fontcolor=folded_fontcolor, fontsize=dots_fontsize)
self.tree.node(name=folded_node_id, label="...", type="value")
self.tree.edge(node.id, folded_node_id, color=folded_fontcolor)
elif node.type == "param":
if node.label in selected_parameters.keys() and (node == self.root or node.parent.selected):
color = "lightblue3"
node.selected=True
else:
color = color_param
node.selected=False
n_label = label_parser.get(node.label, node.label)
n_label = add_fold_symbol(n_label, folded=False)
self.tree.attr('node', shape='box', style="filled", color=color, fontcolor=fontcolor, fontsize=fontsize)
self.tree.node(name=node.id, label=n_label, type="parameter")
nodes.extend(node.children)
elif node.type == "value":
if (selected_parameters.get(node.parent.label, "Not existent") == node.label) and (node == self.root or node.parent.selected):
# if node.label in selected_parameters.values() and (node == self.root or node.parent.selected):
color = "lightblue2"
node.selected = True
else:
color = color_value
node.selected = False
n_label = label_parser.get(node.label, node.label)
n_label = add_fold_symbol(n_label, folded=False)
# add to vizgraph
self.tree.attr('node', shape='ellipse', style='filled', color=color, fontcolor=fontcolor, fontsize=fontsize)
self.tree.node(name=node.id, label=n_label, type="value")
nodes.extend(node.children)
else:
raise ValueError(f"Undefined node type {node.type}")
if node.parent is not None:
self.tree.edge(node.parent.id, node.id)
# draw image
if save:
self.tree.render(filename)
print("Tree image saved in : {}".format(filename))
if __name__ == '__main__':
# demo of how to use the ParameterTree class
tree = ParameterTree()
env_type_nd = tree.add_node("Env_type", type="param")
inf_seeking_nd = tree.add_node("Information_seeking", parent=env_type_nd, type="value")
collab_nd = tree.add_node("Collaboration", parent=env_type_nd, type="value")
perc_inf_nd = tree.add_node("Perception_inference", parent=env_type_nd, type="value")
raise DeprecationWarning("deprecated parameters")
# Information seeking
scaffolding_nd = tree.add_node("Scaffolding", parent=inf_seeking_nd, type="param")
tree.add_node("lot", parent=scaffolding_nd, type="value")
tree.add_node("medium", parent=scaffolding_nd, type="value")
tree.add_node("little", parent=scaffolding_nd, type="value")
tree.add_node("none", parent=scaffolding_nd, type="value")
prag_fr_compl_nd = tree.add_node("Pragmatic_frame_complexity", parent=inf_seeking_nd, type="param")
tree.add_node("Eye contact", parent=prag_fr_compl_nd, type="value")
tree.add_node("Hello", parent=prag_fr_compl_nd, type="value")
emulation_nd = tree.add_node("Emulation", parent=inf_seeking_nd, type="param")
tree.add_node("N", parent=emulation_nd, type="value")
tree.add_node("Y", parent=emulation_nd, type="value")
pointing_nd = tree.add_node("Pointing", parent=inf_seeking_nd, type="param")
tree.add_node("No", parent=pointing_nd, type="value")
tree.add_node("Direct", parent=pointing_nd, type="value")
tree.add_node("Indirect", parent=pointing_nd, type="value")
language_graounding_nd = tree.add_node("Language_grounding", parent=inf_seeking_nd, type="param")
tree.add_node("No", parent=language_graounding_nd, type="value")
tree.add_node("Color", parent=language_graounding_nd, type="value")
tree.add_node("Feedback", parent=language_graounding_nd, type="value")
problem_nd = tree.add_node("Problem", parent=inf_seeking_nd, type="param")
tree.add_node("Boxes", parent=problem_nd, type="value")
tree.add_node("Switches", parent=problem_nd, type="value")
tree.add_node("Corridors", parent=problem_nd, type="value")
obstacles_nd = tree.add_node("Obstacles", parent=inf_seeking_nd, type="param")
tree.add_node("no", parent=obstacles_nd, type="value")
tree.add_node("lava", parent=obstacles_nd, type="value")
tree.add_node("walls", parent=obstacles_nd, type="value")
# Collaboration
colab_type_nd = tree.add_node("Collaboration type", parent=collab_nd, type="param")
tree.add_node("Door Lever", parent=colab_type_nd, type="value")
tree.add_node("Door Button", parent=colab_type_nd, type="value")
tree.add_node("Marble Run", parent=colab_type_nd, type="value")
tree.add_node("Marble Pass", parent=colab_type_nd, type="value")
role_nd = tree.add_node("Role", parent=collab_nd, type="param")
tree.add_node("A", parent=role_nd, type="value")
tree.add_node("B", parent=role_nd, type="value")
tree.add_node("asocial", parent=role_nd, type="value")
# Perception inference
NPC_movement_nd = tree.add_node("NPC movement", parent=perc_inf_nd, type="param")
tree.add_node("can't turn; can't move", parent=NPC_movement_nd, type="value")
tree.add_node("can turn; can't move", parent=NPC_movement_nd, type="value")
tree.add_node("can turn; can move", parent=NPC_movement_nd, type="value")
occlusion_nd = tree.add_node("Occlusions", parent=perc_inf_nd, type="param")
tree.add_node("no", parent=occlusion_nd, type="value")
tree.add_node("walls", parent=occlusion_nd, type="value")
params = tree.sample_env_params()
tree.draw_tree("viz/demotree", params)