Spaces:
No application file
No application file
# stuff specifically for the sklearn logic | |
from typing import Mapping | |
from functools import partial, reduce | |
import operator | |
from itertools import product | |
import argparse | |
###############################################################################
# A grid search convenience class
###############################################################################
class ParameterGrid:
    """Grid of parameter combinations, adapted from scikit-learn's ``ParameterGrid``.

    Using scikit-learn itself (or a smarter sampler) may be preferable in
    future. This trimmed copy drops sklearn's input validation, so the grid
    must already be well-formed: each mapping should map a parameter name to
    a sized, re-iterable collection of candidate values. Note that a bare
    string value is treated as a sequence of characters.

    Parameters
    ----------
    params : Mapping or iterable of Mapping
        A single mapping, or several mappings whose grids are expanded
        independently of each other and concatenated in order.

    Examples
    --------
    >>> list(ParameterGrid({"a": [1, 2], "b": ["x"]}))
    [{'a': 1, 'b': 'x'}, {'a': 2, 'b': 'x'}]
    """

    def __init__(self, params):
        # A single mapping is wrapped so that both input forms are handled
        # uniformly as a list of mappings.
        if isinstance(params, Mapping):
            params = [params]
        # Materialize the input: a generator/one-shot iterator would otherwise
        # be exhausted by the first len() or iteration, leaving an empty grid.
        self.params = list(params)
        # NOTE: no further validation; make sure the param dicts are conforming.

    def __iter__(self):
        """Iterate over the points in the grid.

        Returns
        -------
        params : iterator over dict of str to any
            Yields dictionaries mapping each estimator parameter to one of its
            allowed values. An empty mapping contributes a single empty dict.
        """
        for p in self.params:
            # Always sort the keys of a dictionary, for reproducibility.
            items = sorted(p.items())
            if not items:
                yield {}
            else:
                keys, values = zip(*items)
                for combo in product(*values):
                    yield dict(zip(keys, combo))

    def __len__(self):
        """Number of points on the grid."""
        # Per mapping, multiply the number of candidate values of each key.
        # Starting the product at 1 makes an empty mapping count as the one
        # empty point that __iter__ yields for it.
        return sum(
            reduce(operator.mul, (len(v) for v in p.values()), 1)
            for p in self.params
        )
###############################################################################
# Little "one-liner" reduce helper that turns a shallow dict into
# the list [k1, v1, k2, v2, k3, v3, ...]
# and optionally into the string "k1 v1 k2 v2 k3 v3 ..."
def flatten_dict(dict, to_string=False, sep=" "):
    """Flatten a shallow mapping into alternating keys and values.

    ``{k1: v1, k2: v2}`` becomes ``[k1, v1, k2, v2]`` (in the mapping's
    iteration order), or, with ``to_string=True``, the string
    ``"k1 v1 k2 v2"`` joined by ``sep``. Values are not flattened further.

    Parameters
    ----------
    dict : mapping
        Mapping to flatten. The name shadows the builtin but is kept so
        existing keyword callers keep working.
    to_string : bool
        If true, stringify every element with ``str`` and join with ``sep``.
    sep : str
        Separator used when ``to_string`` is true.

    Returns
    -------
    list or str
        The flat ``[k1, v1, ...]`` list, or the joined string.

    Raises
    ------
    ValueError
        If building the joined string fails; the underlying error is chained
        as ``__cause__``.
    """
    # iconcat extends the accumulator list in place with each (key, value)
    # tuple, so this stays linear in the number of items.
    flat_dict = reduce(operator.iconcat, dict.items(), [])
    if not to_string:
        return flat_dict
    try:
        return sep.join([str(elm) for elm in flat_dict])
    except Exception as err:
        # Catch ordinary errors only (not KeyboardInterrupt/SystemExit) and
        # keep the original exception for debugging.
        raise ValueError(
            f"Error converting dict={flat_dict} to whitespace joined string"
        ) from err
def str2bool(v):
    """Convert a yes/no style command-line value into a bool.

    Intended as an ``argparse`` ``type=`` callable. Booleans pass through
    unchanged; strings are matched case-insensitively against a fixed set
    of true/false spellings.

    Raises
    ------
    argparse.ArgumentTypeError
        If the string is not one of the recognised spellings.
    """
    if isinstance(v, bool):
        return v
    normalized = v.lower()
    spellings = (
        (True, ("yes", "true", "t", "y", "1")),
        (False, ("no", "false", "f", "n", "0")),
    )
    for verdict, accepted in spellings:
        if normalized in accepted:
            return verdict
    raise argparse.ArgumentTypeError("Boolean value expected.")