Spaces:
Running
on
L40S
Running
on
L40S
def MakeSmartType(t): | |
if isinstance(t, str): | |
return SmartType(t) | |
return t | |
class SmartType(str): | |
def __ne__(self, other): | |
if self == "*" or other == "*": | |
return False | |
selfset = set(self.split(',')) | |
otherset = set(other.split(',')) | |
return not selfset.issubset(otherset) | |
def VariantSupport(): | |
def decorator(cls): | |
if hasattr(cls, "INPUT_TYPES"): | |
old_input_types = getattr(cls, "INPUT_TYPES") | |
def new_input_types(*args, **kwargs): | |
types = old_input_types(*args, **kwargs) | |
for category in ["required", "optional"]: | |
if category not in types: | |
continue | |
for key, value in types[category].items(): | |
if isinstance(value, tuple): | |
types[category][key] = (MakeSmartType(value[0]),) + value[1:] | |
return types | |
setattr(cls, "INPUT_TYPES", new_input_types) | |
if hasattr(cls, "RETURN_TYPES"): | |
old_return_types = cls.RETURN_TYPES | |
setattr(cls, "RETURN_TYPES", tuple(MakeSmartType(x) for x in old_return_types)) | |
if hasattr(cls, "VALIDATE_INPUTS"): | |
# Reflection is used to determine what the function signature is, so we can't just change the function signature | |
raise NotImplementedError("VariantSupport does not support VALIDATE_INPUTS yet") | |
else: | |
def validate_inputs(input_types): | |
inputs = cls.INPUT_TYPES() | |
for key, value in input_types.items(): | |
if isinstance(value, SmartType): | |
continue | |
if "required" in inputs and key in inputs["required"]: | |
expected_type = inputs["required"][key][0] | |
elif "optional" in inputs and key in inputs["optional"]: | |
expected_type = inputs["optional"][key][0] | |
else: | |
expected_type = None | |
if expected_type is not None and MakeSmartType(value) != expected_type: | |
return f"Invalid type of {key}: {value} (expected {expected_type})" | |
return True | |
setattr(cls, "VALIDATE_INPUTS", validate_inputs) | |
return cls | |
return decorator | |