"""
This file is part of ComfyUI.
Copyright (C) 2024 Comfy

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""
import torch | |
from . import model_base | |
from . import utils | |
from . import latent_formats | |
class ClipTarget:
    """Pair a tokenizer with a clip object and a per-instance params dict.

    Attributes:
        clip: The ``clip`` argument, stored unchanged.
        tokenizer: The ``tokenizer`` argument, stored unchanged.
        params: A fresh, empty dict created for every instance (never shared
            between instances).
    """

    def __init__(self, tokenizer, clip):
        """Store ``tokenizer`` and ``clip`` and start with empty ``params``."""
        self.clip = clip
        self.tokenizer = tokenizer
        self.params = {}
class BASE:
    """Base description of a supported model configuration.

    Class attributes hold the defaults for a model family; ``__init__`` takes
    per-instance copies of the mutable ones (``unet_config``,
    ``sampling_settings``, ``optimizations``) and instantiates
    ``latent_format`` so that instances do not mutate the class-level values.
    """

    # Key/value pairs that ``matches`` requires to be present (and equal) in
    # a candidate unet config.
    unet_config = {}
    # Entries written into the instance's ``unet_config`` by ``__init__``,
    # overriding any value passed in under the same key.
    unet_extra_config = {
        "num_heads": -1,
        "num_head_channels": 64,
    }

    # Keys that ``matches`` requires to be present in a state dict, when one
    # is supplied.
    required_keys = {}

    clip_prefix = []
    # Prefix used by ``process_clip_vision_state_dict_for_saving``; ``None``
    # means no prefix replacement is requested there.
    clip_vision_prefix = None
    # When not ``None``, ``get_model`` builds ``model_base.SD21UNCLIP`` with
    # this value instead of ``model_base.BaseModel``.
    noise_aug_config = None
    sampling_settings = {}
    # A class (or other callable); ``__init__`` replaces it on the instance
    # with the result of calling it.
    latent_format = latent_formats.LatentFormat
    # The first entry is the prefix applied when saving VAE weights.
    vae_key_prefix = ["first_stage_model."]
    # All entries are stripped on load; the first is applied on save.
    text_encoder_key_prefix = ["cond_stage_model."]
    supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]

    memory_usage_factor = 2.0

    manual_cast_dtype = None
    custom_operations = None
    scaled_fp8 = None
    optimizations = {"fp8": False}

    # NOTE(review): the first parameter is named ``s`` rather than ``self``
    # and only ``unet_config`` / ``required_keys`` are read from it. This
    # looks like it may have been intended as a ``@classmethod`` -- confirm
    # against callers before adding the decorator, since that would make
    # instance calls compare against the class-level ``unet_config``.
    def matches(s, unet_config, state_dict=None):
        """Return whether ``unet_config`` (and ``state_dict``) fit this config.

        Args:
            unet_config: Mapping to compare against ``s.unet_config``.
            state_dict: Optional container checked for ``s.required_keys``;
                skipped entirely when ``None``.

        Returns:
            ``False`` if any key of ``s.unet_config`` is missing from
            ``unet_config`` or maps to a different value, or if
            ``state_dict`` is given and lacks any key of ``s.required_keys``;
            ``True`` otherwise.
        """
        for k in s.unet_config:
            if k not in unet_config or s.unet_config[k] != unet_config[k]:
                return False
        if state_dict is not None:
            for k in s.required_keys:
                if k not in state_dict:
                    return False
        return True

    def model_type(self, state_dict, prefix=""):
        """Return the model type; this base version always returns ``ModelType.EPS``.

        ``state_dict`` and ``prefix`` are accepted but unused here.
        """
        return model_base.ModelType.EPS

    def inpaint_model(self):
        """Return ``True`` when ``unet_config["in_channels"]`` is greater than 4.

        Raises:
            KeyError: If ``"in_channels"`` is absent from ``unet_config``.
        """
        return self.unet_config["in_channels"] > 4

    def __init__(self, unet_config):
        """Create a per-instance configuration from ``unet_config``.

        Args:
            unet_config: Mapping that is shallow-copied; entries of
                ``unet_extra_config`` are then written over the copy.
        """
        self.unet_config = unet_config.copy()
        # Copy class-level mutables so instance edits stay local.
        self.sampling_settings = self.sampling_settings.copy()
        self.latent_format = self.latent_format()
        self.optimizations = self.optimizations.copy()
        for x in self.unet_extra_config:
            self.unet_config[x] = self.unet_extra_config[x]

    def get_model(self, state_dict, prefix="", device=None):
        """Build the model object for this configuration.

        Uses ``model_base.SD21UNCLIP`` when ``noise_aug_config`` is set and
        ``model_base.BaseModel`` otherwise, passing the result of
        ``self.model_type(state_dict, prefix)`` as ``model_type``. Calls
        ``set_inpaint()`` on the result when ``inpaint_model()`` is true.

        Returns:
            The constructed model object.
        """
        if self.noise_aug_config is not None:
            out = model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=self.model_type(state_dict, prefix), device=device)
        else:
            out = model_base.BaseModel(self, model_type=self.model_type(state_dict, prefix), device=device)
        if self.inpaint_model():
            out.set_inpaint()
        return out

    def process_clip_state_dict(self, state_dict):
        """Map every prefix in ``text_encoder_key_prefix`` to the empty string.

        Delegates to ``utils.state_dict_prefix_replace`` with
        ``filter_keys=True`` (presumably dropping keys that match no prefix
        -- confirm against ``utils``).
        """
        state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True)
        return state_dict

    def process_unet_state_dict(self, state_dict):
        """Return ``state_dict`` unchanged."""
        return state_dict

    def process_vae_state_dict(self, state_dict):
        """Return ``state_dict`` unchanged."""
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        """Replace the empty prefix with the first ``text_encoder_key_prefix`` entry."""
        replace_prefix = {"": self.text_encoder_key_prefix[0]}
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)

    def process_clip_vision_state_dict_for_saving(self, state_dict):
        """Replace the empty prefix with ``clip_vision_prefix`` when it is set.

        When ``clip_vision_prefix`` is ``None`` the helper is still called,
        with an empty replacement mapping.
        """
        replace_prefix = {}
        if self.clip_vision_prefix is not None:
            replace_prefix[""] = self.clip_vision_prefix
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)

    def process_unet_state_dict_for_saving(self, state_dict):
        """Replace the empty prefix with ``"model.diffusion_model."``."""
        replace_prefix = {"": "model.diffusion_model."}
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)

    def process_vae_state_dict_for_saving(self, state_dict):
        """Replace the empty prefix with the first ``vae_key_prefix`` entry."""
        replace_prefix = {"": self.vae_key_prefix[0]}
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)

    def set_inference_dtype(self, dtype, manual_cast_dtype):
        """Record ``dtype`` in ``unet_config['dtype']`` and store ``manual_cast_dtype``."""
        self.unet_config['dtype'] = dtype
        self.manual_cast_dtype = manual_cast_dtype