File size: 6,717 Bytes
43b7e92 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
from typing import Any, Dict, List
from .configuration_utils import ConfigMixin, register_to_config
from .utils import CONFIG_NAME
class PipelineCallback(ConfigMixin):
    """
    Base class for all the official callbacks used in a pipeline. This class provides a structure for implementing
    custom callbacks and ensures that all callbacks have a consistent interface.

    Please implement the following:
        `tensor_inputs`: This should return a list of tensor inputs specific to your callback. You will only be able
        to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class.
        `callback_fn`: This method defines the core functionality of your callback.

    Args:
        cutoff_step_ratio: A number in `[0.0, 1.0]`, or `None`. Defaults to `1.0`. Plain integers (`0`, `1`) are
            accepted in addition to floats; booleans are rejected.
        cutoff_step_index: A step index, or `None`. Defaults to `None`. Because `cutoff_step_ratio` defaults to
            `1.0`, callers who want to use `cutoff_step_index` must pass `cutoff_step_ratio=None` explicitly.

    Raises:
        ValueError: If both or neither of `cutoff_step_ratio` and `cutoff_step_index` are provided, or if
            `cutoff_step_ratio` is not a number between 0.0 and 1.0.
    """

    config_name = CONFIG_NAME

    @register_to_config
    def __init__(self, cutoff_step_ratio=1.0, cutoff_step_index=None):
        super().__init__()

        # Exactly one of the two cutoff settings has to be supplied: reject both "neither" and "both".
        if (cutoff_step_ratio is None) == (cutoff_step_index is None):
            raise ValueError("Either cutoff_step_ratio or cutoff_step_index should be provided, not both or none.")

        if cutoff_step_ratio is not None:
            # An int such as 0 or 1 is a perfectly valid ratio (and is what a JSON round-trip of `1.0` may yield),
            # so accept it alongside float. `bool` is an `int` subclass and is excluded on purpose.
            is_real_number = isinstance(cutoff_step_ratio, (int, float)) and not isinstance(cutoff_step_ratio, bool)
            if not is_real_number or not (0.0 <= cutoff_step_ratio <= 1.0):
                raise ValueError("cutoff_step_ratio must be a float between 0.0 and 1.0.")

    @property
    def tensor_inputs(self) -> List[str]:
        """Names of the pipeline variables this callback needs; must be provided by subclasses."""
        raise NotImplementedError(f"You need to set the attribute `tensor_inputs` for {self.__class__}")

    def callback_fn(self, pipeline, step_index, timesteps, callback_kwargs) -> Dict[str, Any]:
        """
        Core logic of the callback; must be implemented by subclasses.

        `__call__` forwards its `timestep` argument positionally into the `timesteps` parameter here.
        """
        raise NotImplementedError(f"You need to implement the method `callback_fn` for {self.__class__}")

    def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        """Invoke `callback_fn` with the given arguments and return its result."""
        return self.callback_fn(pipeline, step_index, timestep, callback_kwargs)
class MultiPipelineCallbacks:
    """
    This class is designed to handle multiple pipeline callbacks. It accepts a list of PipelineCallback objects and
    provides a unified interface for calling all of them.
    """

    def __init__(self, callbacks: List["PipelineCallback"]):
        # The list is stored as given (not copied); callbacks run in list order.
        self.callbacks = callbacks

    @property
    def tensor_inputs(self) -> List[str]:
        """
        Concatenation of every callback's `tensor_inputs`, in callback order. Names requested by more than one
        callback appear more than once.
        """
        # `name` rather than `input`, so the builtin `input` is not shadowed.
        return [name for callback in self.callbacks for name in callback.tensor_inputs]

    def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        """
        Calls all the callbacks in order with the given arguments and returns the final callback_kwargs.

        Each callback receives the `callback_kwargs` returned by the previous one. With no callbacks registered,
        the input `callback_kwargs` is returned unchanged.
        """
        for callback in self.callbacks:
            callback_kwargs = callback(pipeline, step_index, timestep, callback_kwargs)

        return callback_kwargs
class SDCFGCutoffCallback(PipelineCallback):
    """
    Callback for Stable Diffusion pipelines that switches classifier-free guidance (CFG) off at a chosen step.

    The cutoff step is `cutoff_step_index` when that is set; otherwise it is
    `int(pipeline.num_timesteps * cutoff_step_ratio)`.

    Note: When `step_index` equals the cutoff step, this callback mutates the pipeline by setting its
    `_guidance_scale` attribute to 0.0 and replaces `prompt_embeds` with its last entry only.
    """

    tensor_inputs = ["prompt_embeds"]

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        """Disable CFG on the cutoff step and return the (possibly updated) `callback_kwargs`."""
        ratio = self.config.cutoff_step_ratio
        fixed_index = self.config.cutoff_step_index

        # An explicit index takes precedence; the ratio is only consulted when no index is configured.
        if fixed_index is not None:
            trigger_step = fixed_index
        else:
            trigger_step = int(pipeline.num_timesteps * ratio)

        # Nothing to do on any step other than the cutoff step itself.
        if step_index != trigger_step:
            return callback_kwargs

        embeds_key = self.tensor_inputs[0]
        # "-1" denotes the embeddings for conditional text tokens.
        conditional_embeds = callback_kwargs[embeds_key][-1:]

        pipeline._guidance_scale = 0.0
        callback_kwargs[embeds_key] = conditional_embeds
        return callback_kwargs
class SDXLCFGCutoffCallback(PipelineCallback):
    """
    Callback for Stable Diffusion XL pipelines that switches classifier-free guidance (CFG) off at a chosen step.

    The cutoff step is `cutoff_step_index` when that is set; otherwise it is
    `int(pipeline.num_timesteps * cutoff_step_ratio)`.

    Note: When `step_index` equals the cutoff step, this callback mutates the pipeline by setting its
    `_guidance_scale` attribute to 0.0 and replaces `prompt_embeds`, `add_text_embeds` and `add_time_ids` with
    their last entry only.
    """

    tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"]

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        """Disable CFG on the cutoff step and return the (possibly updated) `callback_kwargs`."""
        ratio = self.config.cutoff_step_ratio
        fixed_index = self.config.cutoff_step_index

        # An explicit index takes precedence; the ratio is only consulted when no index is configured.
        if fixed_index is not None:
            trigger_step = fixed_index
        else:
            trigger_step = int(pipeline.num_timesteps * ratio)

        # Nothing to do on any step other than the cutoff step itself.
        if step_index != trigger_step:
            return callback_kwargs

        # "-1" denotes the conditional part of each input: text-token embeddings, pooled text embeddings and
        # the added time vector, in that order. All three are sliced before anything is mutated.
        conditional_parts = []
        for position in range(3):
            key = self.tensor_inputs[position]
            conditional_parts.append((key, callback_kwargs[key][-1:]))

        pipeline._guidance_scale = 0.0

        for key, conditional_value in conditional_parts:
            callback_kwargs[key] = conditional_value
        return callback_kwargs
class IPAdapterScaleCutoffCallback(PipelineCallback):
    """
    Callback for any pipeline that inherits `IPAdapterMixin`; it sets the IP Adapter scale to `0.0` at a chosen
    step.

    The cutoff step is `cutoff_step_index` when that is set; otherwise it is
    `int(pipeline.num_timesteps * cutoff_step_ratio)`.

    Note: When `step_index` equals the cutoff step, this callback calls `pipeline.set_ip_adapter_scale(0.0)`, which
    mutates the IP Adapter attention processors by setting their scale to 0.0.
    """

    tensor_inputs = []

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        """Zero the IP Adapter scale on the cutoff step and return `callback_kwargs` untouched."""
        ratio = self.config.cutoff_step_ratio
        fixed_index = self.config.cutoff_step_index

        # An explicit index takes precedence; the ratio is only consulted when no index is configured.
        if fixed_index is not None:
            trigger_step = fixed_index
        else:
            trigger_step = int(pipeline.num_timesteps * ratio)

        if step_index == trigger_step:
            pipeline.set_ip_adapter_scale(0.0)

        return callback_kwargs
|