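// NOTE: The following lines are residue from the web page this file was scraped from
// (file-viewer header); they are kept here as a comment so the module remains valid JavaScript.
//   multimodalart's picture
//   Squashing commit
//   4450790 verified
//   raw
//   history blame
//   13.2 kB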
import { app } from "../../scripts/app.js";
import { IoDirection, addConnectionLayoutSupport, addMenuItem, matchLocalSlotsToServer, replaceNode, } from "./utils.js";
import { RgthreeBaseServerNode } from "./base_node.js";
import { SERVICE as KEY_EVENT_SERVICE } from "./services/key_events_services.js";
import { debounce, wait } from "../../rgthree/common/shared_utils.js";
import { removeUnusedInputsFromEnd } from "./utils_inputs_outputs.js";
import { NodeTypesString } from "./constants.js";
/**
 * Collapses combo-style slot types into the single marker "COMBO".
 *
 * A slot type is treated as a combo when it is an array or a comma-separated
 * string. Anything else (including non-string types such as numbers, or a
 * missing type) is returned untouched so callers can still compare it strictly.
 */
function normalizeContextSlotType(slotType) {
  if (Array.isArray(slotType)) {
    return "COMBO";
  }
  // Only strings have `includes`; guarding here avoids a TypeError on numeric/undefined types.
  if (typeof slotType === "string" && slotType.includes(",")) {
    return "COMBO";
  }
  return slotType;
}

/**
 * Upper-cases a slot name and strips the first "OPT_" and the first "_NAME"
 * occurrence so that, e.g., "sampler_name" and "SAMPLER" compare equal.
 * A missing name is treated as the empty string.
 */
function normalizeContextSlotName(slotName) {
  return String(slotName || "")
    .toUpperCase()
    .replace("OPT_", "")
    .replace("_NAME", "");
}

/**
 * Finds the index of the context slot that best matches a slot on another node.
 *
 * For CONDITIONING, INT, STRING, FLOAT and COMBO slots the match requires the
 * same (normalized) type AND a compatible name, with a few special cases:
 * seed-like names map to "SEED", "at_step"/"refiner_step" map to
 * "STEP_REFINER", and positive/negative slots are chosen based on whether the
 * other node's type or title contains "POSITIVE"/"NEGATIVE". For every other
 * slot type the first context slot with a strictly equal type wins.
 *
 * @param otherNode The node owning `otherSlot`; its `type` and `title` are read.
 * @param otherSlot The slot (input or output) to match; its `type` and `name` are read.
 * @param ctxSlots The context node's slots to search through.
 * @returns The index into `ctxSlots` of the matching slot, or -1 if none matches.
 */
function findMatchingIndexByTypeOrName(otherNode, otherSlot, ctxSlots) {
  const otherSlotType = normalizeContextSlotType(otherSlot.type);

  // Types outside this list are unambiguous enough to match on type alone.
  const matchedByName = ["CONDITIONING", "INT", "STRING", "FLOAT", "COMBO"].includes(otherSlotType);
  if (!matchedByName) {
    return ctxSlots.findIndex((ctxSlot) => ctxSlot.type === otherSlotType);
  }

  const otherNodeType = String(otherNode.type || "").toUpperCase();
  const otherNodeName = String(otherNode.title || "").toUpperCase();
  const otherSlotName = normalizeContextSlotName(otherSlot.name);
  const isPositiveNode = otherNodeType.includes("POSITIVE") || otherNodeName.includes("POSITIVE");
  const isNegativeNode = otherNodeType.includes("NEGATIVE") || otherNodeName.includes("NEGATIVE");

  return ctxSlots.findIndex((ctxSlot) => {
    if (normalizeContextSlotType(ctxSlot.type) !== otherSlotType) {
      return false;
    }
    const ctxSlotName = normalizeContextSlotName(ctxSlot.name);
    if (
      ctxSlotName === otherSlotName ||
      (ctxSlotName === "SEED" && otherSlotName.includes("SEED")) ||
      (ctxSlotName === "STEP_REFINER" && otherSlotName.includes("AT_STEP")) ||
      (ctxSlotName === "STEP_REFINER" && otherSlotName.includes("REFINER_STEP"))
    ) {
      return true;
    }
    if (
      isPositiveNode &&
      ((ctxSlotName === "POSITIVE" && otherSlotType === "CONDITIONING") ||
        (ctxSlotName === "TEXT_POS_G" && otherSlotName.includes("TEXT_G")) ||
        (ctxSlotName === "TEXT_POS_L" && otherSlotName.includes("TEXT_L")))
    ) {
      return true;
    }
    if (
      isNegativeNode &&
      ((ctxSlotName === "NEGATIVE" && otherSlotType === "CONDITIONING") ||
        (ctxSlotName === "TEXT_NEG_G" && otherSlotName.includes("TEXT_G")) ||
        (ctxSlotName === "TEXT_NEG_L" && otherSlotName.includes("TEXT_L")))
    ) {
      return true;
    }
    return false;
  });
}
/**
 * Shared base class for the rgthree context nodes.
 *
 * Adds "smart" bulk connecting on top of the inherited connect-by-type
 * behaviour: when a connection by type on slot 0 does not succeed, every
 * compatible slot of the other node is wired to the matching context slot.
 */
export class BaseContextNode extends RgthreeBaseServerNode {
  constructor(title) {
    super(title);
    // Backing field for the `_collapsed_width` accessor pair below.
    this.___collapsed_width = 0;
  }

  get _collapsed_width() {
    return this.___collapsed_width;
  }

  /**
   * Ignores the assigned `width` and instead stores a width derived from the
   * node's trimmed title, measured with the canvas' title font: 30 for an empty
   * title, otherwise 30 + 10 + the measured text width.
   */
  set _collapsed_width(width) {
    const canvas = app.canvas;
    const ctx = canvas.canvas.getContext("2d");
    const previousFont = ctx.font;
    ctx.font = canvas.title_text_font;
    const title = this.title.trim();
    this.___collapsed_width = 30 + (title ? 10 + ctx.measureText(title).width : 0);
    // Restore the font so measuring has no lasting effect on the shared 2d context.
    ctx.font = previousFont;
  }

  /**
   * Connects one of this node's outputs to `sourceNode` by type. If the
   * inherited implementation reports no connection and `slot` is 0, each input
   * of `sourceNode` is matched against this node's outputs and connected.
   * Inputs that already have a link are skipped unless ctrl is held.
   *
   * All arguments are forwarded unchanged to the inherited implementation.
   * @returns Always `null`.
   */
  connectByType(slot, sourceNode, sourceSlotType, optsIn) {
    // Prefer the parent's implementation; fall back to LGraphNode's when the parent has none.
    const inherited = super.connectByType || LGraphNode.prototype.connectByType;
    const canConnect = inherited.call(this, slot, sourceNode, sourceSlotType, optsIn);
    if (!canConnect && slot === 0) {
      const ctrlKey = KEY_EVENT_SERVICE.ctrlKey;
      for (const [index, input] of (sourceNode.inputs || []).entries()) {
        if (input.link && !ctrlKey) {
          continue;
        }
        const thisOutputSlot = findMatchingIndexByTypeOrName(sourceNode, input, this.outputs);
        if (thisOutputSlot > -1) {
          this.connect(thisOutputSlot, sourceNode, index);
        }
      }
    }
    return null;
  }

  /**
   * Mirror of `connectByType` for the opposite direction: connects outputs of
   * `sourceNode` into this node's inputs. If the inherited implementation
   * reports no connection and `slot` is 0, each output of `sourceNode` is
   * matched against this node's inputs and connected. Outputs that already
   * have links are skipped unless ctrl is held.
   *
   * All arguments are forwarded unchanged to the inherited implementation.
   * @returns Always `null`.
   */
  connectByTypeOutput(slot, sourceNode, sourceSlotType, optsIn) {
    // The fallback decision must look at `connectByTypeOutput` (not `connectByType`),
    // otherwise a parent lacking only one of the two methods gets the wrong treatment.
    const inherited = super.connectByTypeOutput || LGraphNode.prototype.connectByTypeOutput;
    const canConnect = inherited.call(this, slot, sourceNode, sourceSlotType, optsIn);
    if (!canConnect && slot === 0) {
      const ctrlKey = KEY_EVENT_SERVICE.ctrlKey;
      for (const [index, output] of (sourceNode.outputs || []).entries()) {
        const hasLinks = output.links && output.links.length;
        if (hasLinks && !ctrlKey) {
          continue;
        }
        const thisInputSlot = findMatchingIndexByTypeOrName(sourceNode, output, this.inputs);
        if (thisInputSlot > -1) {
          sourceNode.connect(index, this, thisInputSlot);
        }
      }
    }
    return null;
  }

  /**
   * Registers `ctxClass` as the override for `comfyClass` and, once
   * `wait(500)` resolves (presumably a 500 ms delay — confirm in shared_utils),
   * appends `comfyClass.comfyClass` to LiteGraph's default output node list
   * for the "RGTHREE_CONTEXT" slot type.
   */
  static setUp(comfyClass, nodeData, ctxClass) {
    RgthreeBaseServerNode.registerForOverride(comfyClass, nodeData, ctxClass);
    wait(500).then(() => {
      LiteGraph.slot_types_default_out["RGTHREE_CONTEXT"] =
        LiteGraph.slot_types_default_out["RGTHREE_CONTEXT"] || [];
      LiteGraph.slot_types_default_out["RGTHREE_CONTEXT"].push(comfyClass.comfyClass);
    });
  }

  /**
   * Adds Left/Right connection-layout support to `ctxClass` and, on a later
   * tick, copies the category from `comfyClass` onto it.
   */
  static onRegisteredForOverride(comfyClass, ctxClass) {
    addConnectionLayoutSupport(ctxClass, app, [
      ["Left", "Right"],
      ["Right", "Left"],
    ]);
    setTimeout(() => {
      ctxClass.category = comfyClass.category;
    });
  }
}
/**
 * Client-side class for the standard Context node.
 * Offers a menu entry to swap the node for its "Big" variant.
 */
class ContextNode extends BaseContextNode {
  constructor(title = ContextNode.title) {
    super(title);
  }

  /** Delegates to the base set-up, passing this class as the context class. */
  static setUp(comfyClass, nodeData) {
    BaseContextNode.setUp(comfyClass, nodeData, ContextNode);
  }

  /** Runs the base registration hook, then installs the conversion menu item. */
  static onRegisteredForOverride(comfyClass, ctxClass) {
    BaseContextNode.onRegisteredForOverride(comfyClass, ctxClass);
    const convertMenuItem = {
      name: "Convert To Context Big",
      callback: (selectedNode) => {
        replaceNode(selectedNode, ContextBigNode.type);
      },
    };
    addMenuItem(ContextNode, app, convertMenuItem);
  }
}
// Title, type and comfyClass all use the same node-type string.
Object.assign(ContextNode, {
  title: NodeTypesString.CONTEXT,
  type: NodeTypesString.CONTEXT,
  comfyClass: NodeTypesString.CONTEXT,
});
/**
 * Client-side class for the Context Big node.
 * Offers a menu entry to swap the node back to the original Context node.
 */
class ContextBigNode extends BaseContextNode {
  constructor(title = ContextBigNode.title) {
    super(title);
  }

  /** Delegates to the base set-up, passing this class as the context class. */
  static setUp(comfyClass, nodeData) {
    BaseContextNode.setUp(comfyClass, nodeData, ContextBigNode);
  }

  /** Runs the base registration hook, then installs the conversion menu item. */
  static onRegisteredForOverride(comfyClass, ctxClass) {
    BaseContextNode.onRegisteredForOverride(comfyClass, ctxClass);
    const convertMenuItem = {
      name: "Convert To Context (Original)",
      callback: (selectedNode) => {
        replaceNode(selectedNode, ContextNode.type);
      },
    };
    addMenuItem(ContextBigNode, app, convertMenuItem);
  }
}
// Title, type and comfyClass all use the same node-type string.
Object.assign(ContextBigNode, {
  title: NodeTypesString.CONTEXT_BIG,
  type: NodeTypesString.CONTEXT_BIG,
  comfyClass: NodeTypesString.CONTEXT_BIG,
});
/**
 * Base class for context nodes that accept a variable number of
 * "RGTHREE_CONTEXT" inputs. Starts with five inputs and re-balances the list
 * (via `stabilize`) shortly after any input connection changes.
 */
class BaseContextMultiCtxInputNode extends BaseContextNode {
  constructor(title) {
    super(title);
    // Bound once so the same function reference is handed to `debounce` on every call.
    this.stabilizeBound = this.stabilize.bind(this);
    this.addContextInput(5);
  }

  /**
   * Appends `num` context inputs, named `ctx_NN` where NN is the 1-based,
   * zero-padded position of the new input.
   */
  addContextInput(num = 1) {
    for (let i = 0; i < num; i++) {
      this.addInput(`ctx_${String(this.inputs.length + 1).padStart(2, "0")}`, "RGTHREE_CONTEXT");
    }
  }

  /**
   * Forwards the event to the parent handler (when one exists) and schedules a
   * stabilize pass whenever the change concerns an input.
   */
  onConnectionsChange(type, slotIndex, isConnected, link, ioSlot) {
    if (super.onConnectionsChange) {
      // Pass through every argument received, not just the named ones.
      super.onConnectionsChange.apply(this, arguments);
    }
    if (type === LiteGraph.INPUT) {
      this.scheduleStabilize();
    }
  }

  /**
   * Schedules a debounced `stabilize` call.
   * @param ms Debounce delay handed to `debounce`; defaults to 64.
   * @returns Whatever `debounce` returns.
   */
  scheduleStabilize(ms = 64) {
    // Honour the caller-supplied delay instead of a hard-coded value.
    return debounce(this.stabilizeBound, ms);
  }

  /**
   * Trims unused trailing inputs and then appends one fresh context input.
   * NOTE(review): the `4` is passed to `removeUnusedInputsFromEnd`; presumably
   * the minimum number of inputs to keep — confirm in utils_inputs_outputs.js.
   */
  stabilize() {
    removeUnusedInputsFromEnd(this, 4);
    this.addContextInput();
  }
}
/**
 * Client-side class for the Context Switch node.
 * Offers a menu entry to swap the node for its "Big" variant.
 */
class ContextSwitchNode extends BaseContextMultiCtxInputNode {
  constructor(title = ContextSwitchNode.title) {
    super(title);
  }

  /** Delegates to the base set-up, passing this class as the context class. */
  static setUp(comfyClass, nodeData) {
    BaseContextNode.setUp(comfyClass, nodeData, ContextSwitchNode);
  }

  /** Runs the base registration hook, then installs the conversion menu item. */
  static onRegisteredForOverride(comfyClass, ctxClass) {
    BaseContextNode.onRegisteredForOverride(comfyClass, ctxClass);
    const convertMenuItem = {
      name: "Convert To Context Switch Big",
      callback: (selectedNode) => {
        replaceNode(selectedNode, ContextSwitchBigNode.type);
      },
    };
    addMenuItem(ContextSwitchNode, app, convertMenuItem);
  }
}
// Title, type and comfyClass all use the same node-type string.
Object.assign(ContextSwitchNode, {
  title: NodeTypesString.CONTEXT_SWITCH,
  type: NodeTypesString.CONTEXT_SWITCH,
  comfyClass: NodeTypesString.CONTEXT_SWITCH,
});
/**
 * Client-side class for the Context Switch Big node.
 * Offers a menu entry to swap the node back to the regular Context Switch.
 */
class ContextSwitchBigNode extends BaseContextMultiCtxInputNode {
  constructor(title = ContextSwitchBigNode.title) {
    super(title);
  }

  /** Delegates to the base set-up, passing this class as the context class. */
  static setUp(comfyClass, nodeData) {
    BaseContextNode.setUp(comfyClass, nodeData, ContextSwitchBigNode);
  }

  /** Runs the base registration hook, then installs the conversion menu item. */
  static onRegisteredForOverride(comfyClass, ctxClass) {
    BaseContextNode.onRegisteredForOverride(comfyClass, ctxClass);
    const convertMenuItem = {
      name: "Convert To Context Switch",
      callback: (selectedNode) => {
        replaceNode(selectedNode, ContextSwitchNode.type);
      },
    };
    addMenuItem(ContextSwitchBigNode, app, convertMenuItem);
  }
}
// Title, type and comfyClass all use the same node-type string.
Object.assign(ContextSwitchBigNode, {
  title: NodeTypesString.CONTEXT_SWITCH_BIG,
  type: NodeTypesString.CONTEXT_SWITCH_BIG,
  comfyClass: NodeTypesString.CONTEXT_SWITCH_BIG,
});
/**
 * Client-side class for the Context Merge node.
 * Offers a menu entry to swap the node for its "Big" variant.
 */
class ContextMergeNode extends BaseContextMultiCtxInputNode {
  constructor(title = ContextMergeNode.title) {
    super(title);
  }

  /** Delegates to the base set-up, passing this class as the context class. */
  static setUp(comfyClass, nodeData) {
    BaseContextNode.setUp(comfyClass, nodeData, ContextMergeNode);
  }

  /** Runs the base registration hook, then installs the conversion menu item. */
  static onRegisteredForOverride(comfyClass, ctxClass) {
    BaseContextNode.onRegisteredForOverride(comfyClass, ctxClass);
    const convertMenuItem = {
      name: "Convert To Context Merge Big",
      callback: (selectedNode) => {
        replaceNode(selectedNode, ContextMergeBigNode.type);
      },
    };
    addMenuItem(ContextMergeNode, app, convertMenuItem);
  }
}
// Title, type and comfyClass all use the same node-type string.
Object.assign(ContextMergeNode, {
  title: NodeTypesString.CONTEXT_MERGE,
  type: NodeTypesString.CONTEXT_MERGE,
  comfyClass: NodeTypesString.CONTEXT_MERGE,
});
/**
 * Client-side class for the Context Merge Big node.
 * Offers a menu entry to swap the node back to the regular Context Merge.
 */
class ContextMergeBigNode extends BaseContextMultiCtxInputNode {
  constructor(title = ContextMergeBigNode.title) {
    super(title);
  }

  /** Delegates to the base set-up, passing this class as the context class. */
  static setUp(comfyClass, nodeData) {
    BaseContextNode.setUp(comfyClass, nodeData, ContextMergeBigNode);
  }

  /** Runs the base registration hook, then installs the conversion menu item. */
  static onRegisteredForOverride(comfyClass, ctxClass) {
    BaseContextNode.onRegisteredForOverride(comfyClass, ctxClass);
    addMenuItem(ContextMergeBigNode, app, {
      // The callback converts to the Merge node, so the label must say "Merge" (not "Switch").
      name: "Convert To Context Merge",
      callback: (node) => {
        replaceNode(node, ContextMergeNode.type);
      },
    });
  }
}
ContextMergeBigNode.title = NodeTypesString.CONTEXT_MERGE_BIG;
ContextMergeBigNode.type = NodeTypesString.CONTEXT_MERGE_BIG;
ContextMergeBigNode.comfyClass = NodeTypesString.CONTEXT_MERGE_BIG;
// Every context node class handled by this extension; searched by `type` when a
// node definition is registered.
const contextNodes = [
ContextNode,
ContextBigNode,
ContextSwitchNode,
ContextSwitchBigNode,
ContextMergeNode,
ContextMergeBigNode,
];
// Maps a context node type string to the server node definition (`nodeData`)
// received in `beforeRegisterNodeDef`.
const contextTypeToServerDef = {};
/**
 * Repairs known-bad slot names on a node in place.
 *
 * Currently renames the first output called "CLIP_HEIGTH" (a misspelling) to
 * "CLIP_HEIGHT". Nodes without an `outputs` list are left untouched.
 *
 * @param node The node whose outputs should be checked.
 */
function fixBadConfigs(node) {
  const outputs = node.outputs || [];
  const misspelled = outputs.find((output) => output.name === "CLIP_HEIGTH");
  if (misspelled) {
    misspelled.name = "CLIP_HEIGHT";
  }
}
/**
 * Brings a context node's slots in line with its recorded server definition.
 *
 * Does nothing for nodes whose type has no recorded definition. Otherwise it
 * repairs known-bad slot names, matches the outputs to the server definition,
 * and matches the inputs too — except for Switch and Merge types.
 */
function syncContextNodeWithServerDef(node) {
  const type = node.type || node.constructor.type;
  const serverDef = type && contextTypeToServerDef[type];
  if (!serverDef) {
    return;
  }
  fixBadConfigs(node);
  matchLocalSlotsToServer(node, IoDirection.OUTPUT, serverDef);
  const isSwitchOrMerge = type.includes("Switch") || type.includes("Merge");
  if (!isSwitchOrMerge) {
    matchLocalSlotsToServer(node, IoDirection.INPUT, serverDef);
  }
}

app.registerExtension({
  name: "rgthree.Context",

  // Records the server definition for a context node type and sets up its class.
  async beforeRegisterNodeDef(nodeType, nodeData) {
    const ctxClass = contextNodes.find((candidate) => nodeData.name === candidate.type);
    if (!ctxClass) {
      return;
    }
    contextTypeToServerDef[ctxClass.type] = nodeData;
    ctxClass.setUp(nodeType, nodeData);
  },

  // Newly created nodes and nodes loaded from a graph get the same slot sync.
  async nodeCreated(node) {
    syncContextNodeWithServerDef(node);
  },

  async loadedGraphNode(node) {
    syncContextNodeWithServerDef(node);
  },
});