# multimodalart's picture
# Squashing commit
# 4450790 verified
import json
from ..log import log
def deserialize_curve(curve):
    """Return *curve* as a Python object, decoding it first when it is a JSON string.

    Any non-string value is returned untouched (the very same object).
    """
    return json.loads(curve) if isinstance(curve, str) else curve
def serialize_curve(curve):
    """Return *curve* as a JSON string.

    String inputs are passed through as-is (treated as already serialized);
    anything else is encoded with ``json.dumps``.
    """
    return curve if isinstance(curve, str) else json.dumps(curve)
class MTB_Curve:
    """A basic FLOAT_CURVE input node.

    Declares a single required ``curve`` input of type FLOAT_CURVE and
    returns it unchanged as a 1-tuple.
    """

    RETURN_TYPES = ("FLOAT_CURVE",)
    FUNCTION = "do_curve"
    CATEGORY = "mtb/curve"

    @classmethod
    def INPUT_TYPES(cls):
        required = {"curve": ("FLOAT_CURVE",)}
        return {"required": required}

    def do_curve(self, curve):
        log.debug(f"Curve: {curve}")
        result = (curve,)
        return result
class MTB_CurveToFloat:
    """Convert a FLOAT_CURVE to a FLOAT or FLOATS.

    The curve's control points are sorted by ``x`` and sampled at ``steps``
    evenly spaced positions, from the smallest to the largest ``x``. The
    ``y`` value at each position comes from linear interpolation between the
    two neighbouring control points.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "curve": ("FLOAT_CURVE", {"forceInput": True}),
                "steps": ("INT", {"default": 10, "min": 2}),
            },
        }

    RETURN_TYPES = ("FLOATS", "FLOAT")
    FUNCTION = "do_curve"
    CATEGORY = "mtb/curve"

    def do_curve(self, curve, steps):
        """Sample ``curve`` at ``steps`` evenly spaced x positions.

        Args:
            curve: Mapping whose values are points with ``"x"`` and ``"y"``
                entries. The mapping's keys are ignored.
            steps: Number of samples. Values below 1 yield no samples. A
                value of 1 samples only the smallest x.

        Returns:
            A 2-tuple holding the same list of interpolated y values for both
            the FLOATS and FLOAT outputs.

        Raises:
            ValueError: If ``curve`` contains no points.
        """
        log.debug(f"Curve: {curve}")
        if not curve:
            raise ValueError("Cannot sample an empty curve")

        # Sort by x (should be handled by the widget). sorted() is stable, so
        # points sharing an x keep their original relative order.
        points = sorted(curve.values(), key=lambda point: point["x"])
        x_values = [point["x"] for point in points]
        y_values = [point["y"] for point in points]

        x_min, x_max = x_values[0], x_values[-1]
        # steps == 1 would divide by zero; in that case sample only x_min.
        step_size = (x_max - x_min) / (steps - 1) if steps > 1 else 0.0

        interpolated_y_values = []
        last = len(x_values) - 1
        seg = 0  # index of the last control point with x <= current_x
        for step in range(steps):
            # Clamp: float rounding can push x_min + step_size * step past
            # x_max on the final step (e.g. 0.1 + 0.2 > 0.3). That would
            # leave no control point at or above current_x.
            current_x = min(x_min + step_size * step, x_max)
            # current_x never decreases, so the cursor only moves forward.
            # With duplicate x values it stops on the last duplicate.
            while seg < last and x_values[seg + 1] <= current_x:
                seg += 1
            x1, y1 = x_values[seg], y_values[seg]
            if current_x == x1 or seg == last:
                # Exactly on a control point (seg == last implies
                # current_x == x_max): no interpolation needed.
                interpolated_y_values.append(y1)
                continue
            # Here x2 > current_x > x1, so x2 - x1 is never zero.
            x2, y2 = x_values[seg + 1], y_values[seg + 1]
            interpolated_y_values.append(
                y1 + (y2 - y1) * (current_x - x1) / (x2 - x1)
            )

        return (interpolated_y_values, interpolated_y_values)
# Node classes defined in this module.
# NOTE(review): presumably collected by the package's node loader — confirm.
__nodes__ = [
    MTB_Curve,
    MTB_CurveToFloat,
]