File size: 2,886 Bytes
028694a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import folder_paths
import comfy.utils
import comfy.model_detection
import comfy.model_management
import comfy.lora
from comfy.model_patcher import ModelPatcher

from .utils import TimestepKeyframeGroup
from .control import ControlNetAdvanced, load_controlnet




def convert_cn_lora_from_diffusers(cn_model: ModelPatcher, lora_path: str):
    lora_data = comfy.utils.load_torch_file(lora_path, safe_load=True)
    unet_dtype = comfy.model_management.unet_dtype()
    for key, value in lora_data.items():
        lora_data[key] = value.to(unet_dtype)
    diffusers_keys = comfy.utils.unet_to_diffusers(cn_model.model.state_dict())

    #lora_data = comfy.model_detection.unet_config_from_diffusers_unet(lora_data, dtype=unet_dtype)



    #key_map = comfy.lora.model_lora_keys_unet(cn_model.model, key_map)
    lora_data = comfy.lora.load_lora(lora_data, to_load=diffusers_keys)

    # TODO: detect if diffusers for sure? not sure if needed at this time, since cn loras are
    # only used currently for LOOSEControl, and those are all in diffusers format
    #unet_dtype = comfy.model_management.unet_dtype()
    #lora_data = comfy.model_detection.unet_config_from_diffusers_unet(lora_data, unet_dtype)
    return lora_data


class ControlNetLoaderWithLoraAdvanced:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "control_net_name": (folder_paths.get_filename_list("controlnet"), ),
                "cn_lora_name": (folder_paths.get_filename_list("controlnet"), ),
                "cn_lora_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
            },
            "optional": {
                "timestep_keyframe": ("TIMESTEP_KEYFRAME", ),
            }
        }

    RETURN_TYPES = ("CONTROL_NET", )
    FUNCTION = "load_controlnet"

    CATEGORY = "Adv-ControlNet ๐Ÿ›‚๐Ÿ…๐Ÿ…’๐Ÿ…/LOOSEControl"

    def load_controlnet(self, control_net_name, cn_lora_name, cn_lora_strength: float,

                        timestep_keyframe: TimestepKeyframeGroup=None

                        ):
        controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
        controlnet: ControlNetAdvanced = load_controlnet(controlnet_path, timestep_keyframe)
        if not isinstance(controlnet, ControlNetAdvanced):
            raise ValueError("Type {} is not compatible with CN LoRA features at this time.")
        # now, try to load CN LoRA
        lora_path = folder_paths.get_full_path("controlnet", cn_lora_name)
        lora_data = convert_cn_lora_from_diffusers(cn_model=controlnet.control_model_wrapped, lora_path=lora_path)
        # apply patches to wrapped control_model
        controlnet.control_model_wrapped.add_patches(lora_data, strength_patch=cn_lora_strength)
        # all done
        return (controlnet,)