|
from typing import List, Union |
|
from copy import deepcopy |
|
from collections import namedtuple |
|
from pathlib import Path |
|
import argparse |
|
from argparse import RawDescriptionHelpFormatter |
|
import yaml |
|
from pydantic import BaseModel as _Base |
|
import os |
|
|
|
|
|
class BaseConf(_Base):
    """Shared pydantic base class for the config models in this module.

    Uses pydantic v1-style ``Config`` options (see the comments below).
    """

    class Config:
        # Keys the model does not declare are dropped instead of rejected.
        extra = "ignore"
        # Validate default values too, not only explicitly supplied ones.
        validate_all = True
        # Allow attributes to be reassigned on an instance after creation.
        allow_mutation = True
|
|
|
|
|
def SingleOrList(inner_type):
    """Build the annotation ``Union[inner_type, List[inner_type]]``.

    Handy for config fields that accept either one value or a list of values.
    """
    list_of_inner = List[inner_type]
    return Union[inner_type, list_of_inner]
|
|
|
|
|
def optional_load_config(fname="config.yml"):
    """Load a YAML config file from the current working directory, if present.

    Args:
        fname: file name (or path) resolved against ``Path.cwd()``.

    Returns:
        The parsed YAML value. Returns ``{}`` when the file does not exist or
        contains no document (empty, whitespace-only or comment-only). When a
        file is found, its raw text is echoed to stdout.
    """
    conf_fname = Path.cwd() / fname
    if not conf_fname.is_file():
        return {}

    # YAML is UTF-8 in practice; don't depend on the platform default encoding.
    raw = conf_fname.read_text(encoding="utf-8")
    print("loaded config\n ")
    print(raw)

    # yaml.safe_load returns None for an empty / comment-only document, and
    # callers unpack the result with ``**`` or index it, so map that to {}.
    cfg = yaml.safe_load(raw) if raw.strip() else None
    return {} if cfg is None else cfg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def write_full_config(cfg_obj, fname="full_config.yml"):
    """Dump a config object as YAML, echo it to stdout, and save it to disk.

    Args:
        cfg_obj: object exposing ``.dict()`` (e.g. a ``BaseConf`` model).
        fname: output path. Relative paths are resolved against the current
            working directory; absolute paths are used as given.
    """
    text = _dict_to_yaml(cfg_obj.dict())
    print(f"\n--- full config ---\n\n{text}\n")
    # _dict_to_yaml keeps non-ASCII characters unescaped (allow_unicode=True),
    # so write UTF-8 explicitly instead of the platform default encoding.
    (Path.cwd() / fname).write_text(text, encoding="utf-8")
|
|
|
|
|
def argparse_cfg_template(curr_cfgs, argv=None):
    """Apply ``--CLAUSE VALUE`` command-line overrides to a config dict.

    Command-line tokens the parser does not itself recognise are consumed
    in ``--CLAUSE VALUE`` pairs. Each pair becomes the clause
    ``{CLAUSE: VALUE}`` and is executed by ``ConfigMaker`` (e.g.
    ``--a.b VALUE`` replaces ``curr_cfgs['a']['b']``). ``-h``/``--help``
    prints the current config as YAML in the help epilog and exits.

    Args:
        curr_cfgs: config dict to start from. It is deep-copied first, so the
            caller's object is not modified.
        argv: token list to parse; ``None`` (default) means ``sys.argv[1:]``.

    Returns:
        The edited config.

    Exits via ``parser.error`` (status 2, with usage) when the overrides are
    not well-formed ``--CLAUSE VALUE`` pairs.
    """
    parser = argparse.ArgumentParser(
        description='Manual spec of configs',
        epilog=f'curr cfgs:\n\n{_dict_to_yaml(curr_cfgs)}',
        formatter_class=RawDescriptionHelpFormatter,
        # With abbreviation matching on, an override key that is a prefix of
        # "help" (e.g. --h) would be treated as --help and exit.
        allow_abbrev=False,
    )

    _, extras = parser.parse_known_args(argv)

    if len(extras) % 2:
        parser.error(
            "overrides must be '--KEY VALUE' pairs; "
            f"got an odd number of tokens: {extras}"
        )

    clauses = []
    for key, value in zip(extras[0::2], extras[1::2]):
        if not key.startswith("--"):
            parser.error(f"please start args with --: {key!r}")
        clauses.append({key[2:]: value})

    # ConfigMaker edits nested nodes in place; work on a private copy.
    maker = ConfigMaker(deepcopy(curr_cfgs))
    for clause in clauses:
        maker.execute_clause(clause)

    return maker.state
|
|
|
|
|
def _dict_to_yaml(arg):
    """Serialise ``arg`` to YAML text, preserving key order and raw Unicode."""
    dump_options = dict(sort_keys=False, allow_unicode=True)
    return yaml.safe_dump(arg, **dump_options)
|
|
|
|
|
def dispatch(module):
    """CLI entry: build ``module`` from gradio_init.yml plus overrides, then run.

    Steps: load gradio_init.yml (if present) and fill in model defaults,
    apply command-line overrides, save the resolved config to
    ``<exp_dir>/<initial>/full_config.yml``, and call ``.run()`` on the
    resulting instance.
    """
    defaults = module(**optional_load_config(fname="gradio_init.yml")).dict()
    resolved = argparse_cfg_template(defaults)
    instance = module(**resolved)

    exp_path = os.path.join(resolved['exp_dir'], resolved['initial'])
    os.makedirs(exp_path, exist_ok=True)
    write_full_config(instance, os.path.join(exp_path, "full_config.yml"))

    instance.run()
|
|
|
def dispatch_gradio(module, prompt, keyword, ti_step, pt_step, seed):
    """Build (but do not run) a ``module`` instance for the Gradio UI.

    Starts from gradio_init.yml and overrides the prompt, the LoRA weights
    path (``<exp_dir>/<keyword>/lora/final_lora.safetensors``), the step
    counts, the experiment name (``initial``) and the random seed.
    """
    cfg = optional_load_config("gradio_init.yml")

    sd_cfg = cfg['sd']
    sd_cfg['prompt'] = prompt
    sd_cfg['dir'] = os.path.join(cfg['exp_dir'], keyword, 'lora/final_lora.safetensors')
    cfg.update(ti_step=ti_step, pt_step=pt_step, initial=keyword, random_seed=seed)

    # Construct once, dump back to a dict, and construct again from that dump
    # (presumably to normalise values through the model; confirm intent).
    normalized = module(**cfg).dict()
    return module(**normalized)
|
|
|
|
|
|
|
|
|
|
|
|
|
class ConfigMaker():
    """Edit a nested dict/list config by executing small text "clauses".

    A clause is a plain string or a one-entry dict ``{text: arg}``. ``text``
    is parsed into (subject, verb, objects): the subject is a dotted path
    into the tree, the verb is one of ``VERBS``, and the objects are the
    tokens after the verb. The edited tree is kept in ``self.state``.
    """

    CMD = namedtuple('cmd', field_names=['sub', 'verb', 'objs'])

    VERBS = ('add', 'replace', 'del')

    def __init__(self, base_node):
        # Stored as given (no copy): edits below the root happen in place.
        self.state = base_node
        # Not read or written anywhere else in the visible code.
        self.clauses = []

    def clone(self):
        """Return a deep copy of this maker, state included."""
        return deepcopy(self)

    def execute_clause(self, raw_clause):
        """Parse ``raw_clause`` and apply it to ``self.state``.

        A dict clause must hold exactly one entry; its value is passed as
        the argument to 'add' / 'replace'. A string clause has argument
        ``None``. The 'del' verb is only accepted from a string clause.
        """
        assert isinstance(raw_clause, (str, dict))
        if isinstance(raw_clause, str):
            text, arg = raw_clause, None
        else:
            assert len(raw_clause) == 1, \
                "a clause can only have 1 statement: {} clauses in {}".format(
                    len(raw_clause), raw_clause
                )
            text = list(raw_clause)[0]
            arg = raw_clause[text]

        cmd = self.parse_clause_cmd(text)
        tracer = NodeTracer(self.state)
        tracer.advance_pointer(path=cmd.sub)

        add_verb, replace_verb, del_verb = type(self).VERBS
        if cmd.verb == add_verb:
            tracer.add(cmd.objs, arg)
        elif cmd.verb == replace_verb:
            tracer.replace(cmd.objs, arg)
        elif cmd.verb == del_verb:
            assert isinstance(raw_clause, str)
            tracer.delete(cmd.objs)
        self.state = tracer.state

    @classmethod
    def parse_clause_cmd(cls, input):
        """Split a clause string into a ``CMD(sub, verb, objs)`` tuple.

        Tokens are whitespace-separated. With no verb, the single token (if
        any) is the subject and the verb defaults to 'replace', which is
        shorthand for writing configs. With a verb, the token before it (if
        any) is the subject and everything after it is the object list.

        Expected results:

            input            sub    verb     objs
            ''               ''     replace  []
            'a.b'            'a.b'  replace  []
            'add'            ''     add      []
            'T add'          'T'    add      []
            'T del a b'      'T'    del      ['a', 'b']
            'P Q'            AssertionError: 2 subjects
            'P Q add a'      AssertionError: 2 subjects
            'P add del b'    ValueError: 2 verbs
        """
        assert isinstance(input, str)
        tokens = input.split()
        verb, verb_inx = cls.scan_for_verb(tokens)

        if verb is None:
            assert len(tokens) <= 1, "no verb present; more than 1 subject: {}"\
                .format(tokens)
            subject = tokens[0] if tokens else ''
            return cls.CMD(sub=subject, verb=cls.VERBS[1], objs=[])

        assert verb_inx <= 1, 'verb {} at inx {}; more than 1 subject in: {}'\
            .format(verb, verb_inx, tokens)
        subject = tokens[0] if verb_inx == 1 else ''
        return cls.CMD(sub=subject, verb=verb, objs=tokens[verb_inx + 1:])

    @classmethod
    def scan_for_verb(cls, input_list):
        """Locate the single verb in a token list.

        Returns ``(verb, index)``, or ``(None, -1)`` if no verb is present.

        Raises:
            ValueError: if more than one distinct verb appears, or a verb
                appears more than once.
        """
        assert isinstance(input_list, list)
        counts = [input_list.count(v) for v in cls.VERBS]
        found = [v for v, c in zip(cls.VERBS, counts) if c > 0]

        if not found:
            return None, -1
        if len(found) > 1:
            raise ValueError("multiple verbs discovered in {}".format(input_list))
        if max(counts) > 1:
            raise ValueError("verbs repeated in cmd: {}".format(input_list))

        verb = found[0]
        return verb, input_list.index(verb)
|
|
|
|
|
class NodeTracer():
    """Walk a nested dict/list tree along a dotted path and edit the node found.

    The tree is wrapped as ``{"_": src_node}``, so even the root is reachable
    as ``parent[child_token]`` and can be replaced like any other node.
    """

    def __init__(self, src_node):
        """
        Args:
            src_node: root of the tree to edit; must be a list or a dict.
                Nodes below the root are edited in place.
        """
        assert isinstance(src_node, (list, dict))

        # Pointer = (container, key-or-index inside it).
        self.child_token = "_"
        self.parent = {self.child_token: src_node}

        # Keep the wrapper so `state` can return the (possibly replaced) root.
        self.root_child_token = self.child_token
        self.root = self.parent

    @property
    def state(self):
        """The current root of the tree."""
        return self.root[self.root_child_token]

    @property
    def pointed(self):
        """The node the pointer currently addresses: ``parent[child_token]``."""
        return self.parent[self.child_token]

    def advance_pointer(self, path):
        """Move the pointer down the tree along the dotted ``path``.

        Tokens made only of decimal digits are used as ``int`` list indices;
        all other tokens are used as ``str`` keys. An empty path leaves the
        pointer unchanged.

        Raises:
            ValueError: if a token cannot be followed (missing key, index out
                of range, wrong key type for the container, or the path
                continues below a scalar value).
        """
        if len(path) == 0:
            return
        # isdecimal (unlike isdigit) only accepts characters int() can parse.
        path_list = [int(tok) if tok.isdecimal() else tok for tok in path.split('.')]

        for i, token in enumerate(path_list):
            self.parent = self.pointed
            self.child_token = token
            try:
                self.pointed
            except (IndexError, KeyError, TypeError) as err:
                # TypeError: indexing a scalar, or a list with a str token.
                raise ValueError(
                    "During the tracing of {}, {}-th token '{}'"
                    " is not present in node {}".format(
                        path, i, self.child_token, self.state
                    )
                ) from err

    def replace(self, objs, arg):
        """Overwrite the pointed-at value with ``arg``, keeping its type.

        ``arg`` is converted with ``str`` first. If the current value is a str,
        that string is stored as-is. Otherwise the string is ``eval``-ed, and
        the result must have exactly the current value's type. The one
        exception is that an ``int`` is accepted for a ``float`` target and
        converted to float.

        Raises:
            TypeError: if the evaluated value's type does not match.
        """
        assert len(objs) == 0
        val_type = type(self.parent[self.child_token])

        arg = str(arg)
        if val_type is not str:
            # SECURITY NOTE(review): eval runs arbitrary expressions taken from
            # the command line / config clauses. Only acceptable for trusted,
            # local input; ast.literal_eval would be safer but would reject
            # expressions such as "2**10" that users may rely on.
            arg = eval(arg)
            if val_type is float and type(arg) is int:
                # Let "--lr 1" set a float field (bool deliberately excluded).
                arg = float(arg)
            if type(arg) is not val_type:
                raise TypeError(
                    f"require {val_type.__name__}, given {type(arg).__name__}"
                )

        self.parent[self.child_token] = arg
|
|