multimodalart's picture
Squashing commit
4450790 verified
import type {ComfyNodeConstructor, ComfyObjectInfo} from "typings/comfy.js";
import type {INodeSlot, LGraphNode, LLink, LGraphCanvas} from "typings/litegraph.js";
import {app} from "scripts/app.js";
import {DynamicContextNodeBase, InputLike} from "./dynamic_context_base.js";
import {NodeTypesString} from "./constants.js";
import {
InputMutation,
SERVICE as CONTEXT_SERVICE,
stripContextInputPrefixes,
getContextOutputName,
} from "./services/context_service.js";
import {getConnectedInputNodesAndFilterPassThroughs} from "./utils.js";
import {debounce, moveArrayItem} from "rgthree/common/shared_utils.js";
import {measureText} from "./utils_canvas.js";
import {SERVICE as CONFIG_SERVICE} from "./services/config_service.js";
type ShadowInputData = {
node: LGraphNode;
slot: number;
shadowIndex: number;
shadowIndexIfShownSingularly: number;
shadowIndexFull: number;
nodeIndex: number;
type: string | -1;
name: string;
key: string;
// isDuplicatedBefore: boolean,
duplicatesBefore: number[];
duplicatesAfter: number[];
};
/**
* The Context Switch node.
*/
class DynamicContextSwitchNode extends DynamicContextNodeBase {
static override title = NodeTypesString.DYNAMIC_CONTEXT_SWITCH;
static override type = NodeTypesString.DYNAMIC_CONTEXT_SWITCH;
static comfyClass = NodeTypesString.DYNAMIC_CONTEXT_SWITCH;
protected override readonly hasShadowInputs = true;
// override hasShadowInputs = true;
/**
* We should be able to assume that `lastInputsList` is the input list after the last, major
* synchronous change. Which should mean, if we're handling a change that is currently live, but
* not represented in our node (like, an upstream node has already removed an input), then we
* should be able to compar the current InputList to this `lastInputsList`.
*/
lastInputsList: ShadowInputData[] = [];
private shadowInputs: (InputLike & {count: number})[] = [
{name: "base_ctx", type: "RGTHREE_DYNAMIC_CONTEXT", link: null, count: 0},
];
constructor(title = DynamicContextSwitchNode.title) {
super(title);
}
override getContextInputsList() {
return this.shadowInputs;
}
override handleUpstreamMutation(mutation: InputMutation) {
this.scheduleHardRefresh();
}
override onConnectionsChange(
type: number,
slotIndex: number,
isConnected: boolean,
link: LLink,
ioSlot: INodeSlot,
): void {
super.onConnectionsChange?.call(this, type, slotIndex, isConnected, link, ioSlot);
if (this.configuring) {
return;
}
if (type === LiteGraph.INPUT) {
this.scheduleHardRefresh();
}
}
scheduleHardRefresh(ms = 64) {
return debounce(() => {
this.refreshInputsAndOutputs();
}, ms);
}
override onNodeCreated() {
this.addInput("ctx_1", "RGTHREE_DYNAMIC_CONTEXT");
this.addInput("ctx_2", "RGTHREE_DYNAMIC_CONTEXT");
this.addInput("ctx_3", "RGTHREE_DYNAMIC_CONTEXT");
this.addInput("ctx_4", "RGTHREE_DYNAMIC_CONTEXT");
this.addInput("ctx_5", "RGTHREE_DYNAMIC_CONTEXT");
super.onNodeCreated();
}
override addContextInput(name: string, type: string, slot?: number): void {}
/**
* This is a "hard" refresh of the list, but looping over the actual context inputs, and
* recompiling the shadowInputs and outputs.
*/
private refreshInputsAndOutputs() {
const inputs: (InputLike & {count: number})[] = [
{name: "base_ctx", type: "RGTHREE_DYNAMIC_CONTEXT", link: null, count: 0},
];
let numConnected = 0;
for (let i = 0; i < this.inputs.length; i++) {
const childCtxs = getConnectedInputNodesAndFilterPassThroughs(
this,
this,
i,
) as DynamicContextNodeBase[];
if (childCtxs.length > 1) {
throw new Error("How is there more than one input?");
}
const ctx = childCtxs[0];
if (!ctx) continue;
numConnected++;
const slotsData = CONTEXT_SERVICE.getDynamicContextInputsData(ctx);
console.log(slotsData);
for (const slotData of slotsData) {
const found = inputs.find(
(n) => getContextOutputName(slotData.name) === getContextOutputName(n.name),
);
if (found) {
found.count += 1;
continue;
}
inputs.push({
name: slotData.name,
type: slotData.type,
link: null,
count: 1,
});
}
}
this.shadowInputs = inputs;
// First output is always CONTEXT, so "p" is the offset.
let i = 0;
for (i; i < this.shadowInputs.length; i++) {
const data = this.shadowInputs[i]!;
let existing = this.outputs.find(
(o) => getContextOutputName(o.name) === getContextOutputName(data.name),
);
if (!existing) {
existing = this.addOutput(getContextOutputName(data.name), data.type);
}
moveArrayItem(this.outputs, existing, i);
delete existing.rgthree_status;
if (data.count !== numConnected) {
existing.rgthree_status = "WARN";
}
}
while (this.outputs[i]) {
const output = this.outputs[i];
if (output?.links?.length) {
output.rgthree_status = "ERROR";
i++;
} else {
this.removeOutput(i);
}
}
this.fixInputsOutputsLinkSlots();
}
override onDrawForeground(ctx: CanvasRenderingContext2D, canvas: LGraphCanvas): void {
const low_quality = (canvas?.ds?.scale ?? 1) < 0.6;
if (low_quality || this.size[0] <= 10) {
return;
}
let y = LiteGraph.NODE_SLOT_HEIGHT - 1;
const w = this.size[0];
ctx.save();
ctx.font = "normal " + LiteGraph.NODE_SUBTEXT_SIZE + "px Arial";
ctx.textAlign = "right";
for (const output of this.outputs) {
if (!output.rgthree_status) {
y += LiteGraph.NODE_SLOT_HEIGHT;
continue;
}
const x = w - 20 - measureText(ctx, output.name);
if (output.rgthree_status === "ERROR") {
ctx.fillText("🛑", x, y);
} else if (output.rgthree_status === "WARN") {
ctx.fillText("⚠️", x, y);
}
y += LiteGraph.NODE_SLOT_HEIGHT;
}
ctx.restore();
}
}
app.registerExtension({
name: "rgthree.DynamicContextSwitch",
async beforeRegisterNodeDef(nodeType: ComfyNodeConstructor, nodeData: ComfyObjectInfo) {
if (!CONFIG_SERVICE.getConfigValue("unreleased.dynamic_context.enabled")) {
return;
}
if (nodeData.name === DynamicContextSwitchNode.type) {
DynamicContextSwitchNode.setUp(nodeType, nodeData);
}
},
});